1#pragma once
  2
  3#include "ggml.h"
  4#include "ggml-cpp.h"
  5#include "clip.h"
  6#include "clip-impl.h"
  7#include "clip-model.h"
  8
  9#include <vector>
 10#include <functional>
 11
// default scale mode + flags used when interpolating (resizing) position embeddings;
// used as the default argument of clip_graph::resize_position_embeddings
#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)
 13
// base class for building a ggml compute graph for a CLIP-style encoder.
// one instance corresponds to one (model, image) pair; model-specific
// subclasses implement build() and assemble the graph from the utility
// builders declared below.
struct clip_graph {
    const clip_model & model;     // model weights/tensors (non-owning reference)
    const clip_hparams & hparams; // model hyperparameters (non-owning reference)
    projector_type proj_type;     // which multimodal projector variant is used

    // we only support single image per batch
    const clip_image_f32 & img;

    // per-graph constants derived from the model/image
    // NOTE(review): values are presumably computed in the constructor (defined
    // in the .cpp) from `model`/`hparams`/`img` — confirm there
    const int patch_size;     // edge length (in pixels) of one square patch
    const int n_patches_x;    // number of patches along image width
    const int n_patches_y;    // number of patches along image height
    const int n_patches;      // total patch count
    const int n_embd;         // embedding dimension of the encoder
    const int n_head;         // number of attention heads
    const int d_head;         // per-head dimension
    const int n_layer;        // number of transformer layers
    const int n_mmproj_embd;  // embedding dimension after the multimodal projector
    const float eps;          // epsilon for normalization layers (see build_norm)
    const float kq_scale;     // attention score scale (see build_attn)
    const clip_flash_attn_type flash_attn_type; // flash-attention mode selection

    ggml_context_ptr ctx0_ptr; // owning RAII handle for the graph-building context
    ggml_context * ctx0;       // raw context pointer used by the builders
                               // (presumably the pointer held by ctx0_ptr — confirm in ctor)
    ggml_cgraph * gf;          // the compute graph being built (allocated in ctx0)

    // binds the graph builder to a clip context and a single input image
    clip_graph(clip_ctx * ctx, const clip_image_f32 & img);

    // virtual dtor: instances are deleted through this polymorphic base
    virtual ~clip_graph() = default;
    // model-specific graph construction; returns the finished cgraph
    virtual ggml_cgraph * build() = 0;

    //
    // utility functions
    //
    // debug/naming callback: names tensor `cur0` with `name` and layer index `il`
    void cb(ggml_tensor * cur0, const char * name, int il) const;

    // siglip2 naflex
    // interpolates the learned position embeddings to the current patch grid
    ggml_tensor * resize_position_embeddings(uint32_t interpolation_mode = DEFAULT_INTERPOLATION_MODE);

    // build vision transformer (ViT) cgraph
    // this function should cover most of the models
    // if your model has specific features, you should probably duplicate this function
    // `add_pos` lets the caller inject per-layer position encoding (e.g. RoPE);
    // `learned_pos_embd` (may be nullptr) is added to the input embeddings
    ggml_tensor * build_vit(
                ggml_tensor * inp,
                int64_t n_pos,
                norm_type norm_t,
                ffn_op_type ffn_t,
                ggml_tensor * learned_pos_embd,
                std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos);

    // build the input after conv2d (inp_raw --> patches)
    // returns tensor with shape [n_embd, n_patches]
    ggml_tensor * build_inp();

    // build the raw image input tensor; `channels` defaults to 3 (RGB)
    ggml_tensor * build_inp_raw(int channels = 3);

    // apply a normalization layer (type selected by `type`) with weight `mw`
    // and optional bias `mb`; `il` is the layer index (used for naming/debug)
    ggml_tensor * build_norm(
            ggml_tensor * cur,
            ggml_tensor * mw,
            ggml_tensor * mb,
            norm_type type,
            float norm_eps,
            int il) const;

    // apply a feed-forward block; `gate`/`gate_b` may be absent for non-gated
    // variants, and `type_op` selects the activation
    ggml_tensor * build_ffn(
            ggml_tensor * cur,
            ggml_tensor * up,
            ggml_tensor * up_b,
            ggml_tensor * gate,
            ggml_tensor * gate_b,
            ggml_tensor * down,
            ggml_tensor * down_b,
            ffn_op_type type_op,
            int il) const;

    // scaled dot-product attention over pre-computed Q/K/V, followed by the
    // output projection `wo` (+ optional bias `wo_b`); `kq_mask` may be nullptr
    ggml_tensor * build_attn(
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur,
            ggml_tensor * k_cur,
            ggml_tensor * v_cur,
            ggml_tensor * kq_mask,
            float kq_scale,
            int il) const;

    // implementation of the 2D RoPE without adding a new op in ggml
    // this is not efficient (use double the memory), but works on all backends
    // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
    ggml_tensor * build_rope_2d(
        ggml_context * ctx0,
        ggml_tensor * cur,
        ggml_tensor * pos_a, // first half
        ggml_tensor * pos_b, // second half
        const float freq_base,
        const bool interleave_freq
    );

    // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
    // support dynamic resolution
    ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor);

    // Generic function to stack frames for audio processing
    // Abstracts out the StackAudioFrames logic used by ultravox
    ggml_tensor * build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed);
};