#include "models.h"

ggml_cgraph * clip_graph_youtuvl::build() {
    GGML_ASSERT(model.class_embedding == nullptr);
    const int batch_size       = 1;
    const bool use_window_attn = !hparams.wa_layer_indexes.empty();
    const int n_pos            = n_patches;
    const int num_position_ids = n_pos * 4; // 4 position ids per patch (one per M-RoPE section)
    const int m = 2; // spatial merge factor: 2x2 blocks of patches are merged downstream
    const int Wp = n_patches_x;
    const int Hp = n_patches_y;
    const int Hm = Hp / m;
    const int Wm = Wp / m;
    norm_type norm_t = NORM_TYPE_NORMAL;

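    // split each head's rotary dims into 4 equal M-RoPE sections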
    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

    ggml_tensor * inp = build_inp_raw();

    // implement the conv3d patch embedding as a linear projection over flattened patches:
    // reshape and permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm)
    {
        inp = ggml_reshape_4d(
            ctx0, inp,
            Wm * m * patch_size, m * patch_size, Hm, 3);
        inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
        inp = ggml_cont_4d(
            ctx0, inp,
            m * patch_size * 3, Wm, m * patch_size, Hm);

        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
        inp = ggml_cont_4d(
            ctx0, inp,
            m * patch_size * 3, patch_size, m, Hm * Wm);

        inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
        inp = ggml_cont_4d(
            ctx0, inp,
            patch_size, 3, patch_size, Hm * Wm * m * m);

        inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
        inp = ggml_cont_3d(
            ctx0, inp,
            3 * patch_size * patch_size, Hm * Wm * m * m, 1);
    }
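    // each row of inp now holds one flattened patch (3 * patch_size * patch_size values),
    // so the patch embedding reduces to a single matrix multiplication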
    inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);

    if (model.patch_bias) {
        inp = ggml_add(ctx0, inp, model.patch_bias);
    }

    inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);

    ggml_tensor * inpL           = inp;
    ggml_tensor * window_mask    = nullptr;
    ggml_tensor * window_idx     = nullptr;
    ggml_tensor * inv_window_idx = nullptr;

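    // M-RoPE position ids (4 per patch), provided as a graph input at eval time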
    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);

    // pre-layernorm
    if (model.pre_ln_w) {
        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
    }
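    // window attention: shuffle patches into per-window order here; the original
    // order is restored after the last layer (see window_idx below)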
    if (use_window_attn) {
        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
        ggml_set_name(inv_window_idx, "inv_window_idx");
        ggml_set_input(inv_window_idx);
        // mask for window attention
        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
        ggml_set_name(window_mask, "window_mask");
        ggml_set_input(window_mask);

        // flash attention requires the mask in f16
        if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
            window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
        }

        // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
        GGML_ASSERT(batch_size == 1);
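        // reorder patches into window-local order: rows are grouped in spatial-merge
        // units of 4 patches, then gathered with inv_window_idx so that each attention
        // window is contiguous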
        inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
        inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
        inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
    }

    // loop over layers
    for (int il = 0; il < n_layer; il++) {
        const auto & layer = model.layers[il];
        // layers listed in wa_layer_indexes use full attention; all others are windowed
        const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true;

        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states

        // layernorm1
        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
        // self-attention
        {
            ggml_tensor * Qcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
            ggml_tensor * Kcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
            ggml_tensor * Vcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);

            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);

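            // apply M-RoPE to Q and K using the per-patch position ids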
            Qcur = ggml_rope_multi(
                ctx0, Qcur, positions, nullptr,
                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
            Kcur = ggml_rope_multi(
                ctx0, Kcur, positions, nullptr,
                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);

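            // windowed layers attend only within their window via window_mask;
            // full-attention layers pass no mask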
            ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;

            cur = build_attn(layer.o_w, layer.o_b,
                Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
        }
        // re-add the layer input, i.e. the residual connection
        cur = ggml_add(ctx0, cur, inpL);

        inpL = cur; // inpL = residual, cur = hidden_states

        // layernorm2
        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);

        // ffn
        cur = build_ffn(cur,
            layer.ff_up_w, layer.ff_up_b,
            nullptr, nullptr,
            layer.ff_down_w, layer.ff_down_b,
            hparams.ffn_op, il);

        // residual 2
        cur = ggml_add(ctx0, inpL, cur);

        inpL = cur;
    }

    ggml_tensor * embeddings = inpL;
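    // undo the window reordering so the patches are back in their original raster order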
    if (use_window_attn) {
        const int spatial_merge_unit = 4; // = m*m, one 2x2 block of patches
        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit);
        ggml_set_name(window_idx, "window_idx");
        ggml_set_input(window_idx);
        GGML_ASSERT(batch_size == 1);
        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit);
        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size);
        cb(embeddings, "window_order_restored", -1);
    }

    // post-layernorm (part of Siglip2VisionTransformer, applied after the encoder)
    if (model.post_ln_w) {
        embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
    }

    // Now apply the merger (VLPatchMerger):
    // 1. apply RMS norm (ln_q in VLPatchMerger)
    embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
    cb(embeddings, "merger_normed", -1);

    // 2. reshape for the spatial merge (concatenate each 2x2 block of patch embeddings)
    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
    cb(embeddings, "merger_reshaped", -1);

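    // 3. two-layer MLP (mm_0 -> GELU -> mm_1) projecting the merged patch embeddings
    //    into the LLM embedding space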
    embeddings = build_ffn(embeddings,
                    model.mm_0_w, model.mm_0_b,
                    nullptr, nullptr,
                    model.mm_1_w, model.mm_1_b,
                    FFN_GELU,
                    -1);
    ggml_build_forward_expand(gf, embeddings);

    return gf;
}