1#include "models.h"
2
3ggml_cgraph * clip_graph_glm4v::build() {
4 GGML_ASSERT(model.patch_bias != nullptr);
5 GGML_ASSERT(model.position_embeddings != nullptr);
6 GGML_ASSERT(model.class_embedding == nullptr);
7
8 const int batch_size = 1;
9
10 norm_type norm_t = NORM_TYPE_RMS;
11
12 ggml_tensor * inp_raw = build_inp_raw();
13 ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
14
15 int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
16 ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches * 4);
17 ggml_set_name(positions, "positions");
18 ggml_set_input(positions);
19
20 GGML_ASSERT(img.nx % (patch_size * 2) == 0);
21 GGML_ASSERT(img.ny % (patch_size * 2) == 0);
22
23 // second conv dimension
24 {
25 auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
26 inp = ggml_add(ctx0, inp, inp_1);
27
28 inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
29 inp = ggml_cont_4d(
30 ctx0, inp,
31 n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
32 inp = ggml_reshape_4d(
33 ctx0, inp,
34 n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
35 inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
36 inp = ggml_cont_3d(
37 ctx0, inp,
38 n_embd, n_patches_x * n_patches_y, batch_size);
39 }
40
41 // add patch bias
42 inp = ggml_add(ctx0, inp, model.patch_bias);
43 cb(inp, "patch_bias", -1);
44
45 // pos-conv norm
46 inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1);
47
48 // calculate absolute position embedding and apply
49 ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
50 learned_pos_embd = ggml_cont_4d(
51 ctx0, learned_pos_embd,
52 n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
53 learned_pos_embd = ggml_reshape_4d(
54 ctx0, learned_pos_embd,
55 n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
56 learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
57 learned_pos_embd = ggml_cont_3d(
58 ctx0, learned_pos_embd,
59 n_embd, n_patches_x * n_patches_y, batch_size);
60 cb(learned_pos_embd, "learned_pos_embd", -1);
61
62 auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
63 return ggml_rope_multi(
64 ctx0, cur, positions, nullptr,
65 d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
66 32768, hparams.rope_theta, 1, 0, 1, 32, 1);
67 };
68
69 ggml_tensor * cur = build_vit(
70 inp, n_patches,
71 norm_t,
72 hparams.ffn_op,
73 learned_pos_embd,
74 add_pos);
75
76 cb(cur, "vit_out", -1);
77 // cb(ggml_sum(ctx0, cur), "vit_out_sum", -1);
78
79 // GLM4V projector
80 // ref: https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/models/glm4v/modeling_glm4v.py#L116-L130
81
82 // patch merger (downsample)
83 {
84 int n_merge = hparams.n_merge;
85 GGML_ASSERT(n_merge > 0);
86
87 int n_token_out = n_patches / n_merge / n_merge;
88 cur = ggml_reshape_4d(ctx0, cur, n_embd, n_merge, n_merge, n_token_out);
89 cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); // [n_merge, n_merge, n_embd, n_token_out]
90 cur = ggml_conv_2d(ctx0, model.mm_patch_merger_w, cur, n_merge, n_merge, 0, 0, 1, 1);
91 cur = ggml_reshape_2d(ctx0, cur, cur->ne[2], n_token_out); // [n_embd_out, n_token_out]
92
93 cur = ggml_add(ctx0, cur, model.mm_patch_merger_b);
94 }
95
96 // FC projector
97 {
98 cur = ggml_mul_mat(ctx0, model.projection, cur);
99 // default LayerNorm (post_projection_norm)
100 cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
101 cur = ggml_gelu_erf(ctx0, cur);
102 cb(cur, "after_fc_proj", -1);
103 }
104
105 // FFN projector
106 {
107 cur = build_ffn(cur,
108 model.mm_ffn_up_w, model.mm_ffn_up_b,
109 model.mm_ffn_gate_w, model.mm_ffn_gate_b,
110 model.mm_ffn_down_w, model.mm_ffn_down_b,
111 hparams.ffn_op, -1);
112 cb(cur, "after_ffn_proj", -1);
113 // cb(ggml_sum(ctx0, cur), "merged_sum", -1);
114 }
115
116 // build the graph
117 ggml_build_forward_expand(gf, cur);
118
119 return gf;
120}