1#include "ggml-metal-common.h"
  2
  3#include "ggml-impl.h"
  4#include "ggml-backend-impl.h"
  5
  6#include <vector>
  7
  8// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
  9// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
struct ggml_mem_range {
    // buffer id: in this file it is either the address of the backend buffer (allocated tensors)
    // or the address of the base tensor itself (tensors that are not allocated yet)
    uint64_t pb; // buffer id

    // the range is half-open: [p0, p1)
    uint64_t p0; // begin
    uint64_t p1; // end

    // MEM_RANGE_TYPE_SRC if ops read from the range, MEM_RANGE_TYPE_DST if ops write to it
    ggml_mem_range_type pt;
};
 18
// a set of memory ranges, stored as a flat list that is scanned linearly by the checks
struct ggml_mem_ranges {
    // the ranges recorded so far, in insertion order
    std::vector<ggml_mem_range> ranges;

    // verbosity: the add/check helpers emit debug logs when this is > 2
    int debug = 0;
};
 24
 25ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
 26    auto * res = new ggml_mem_ranges;
 27
 28    res->ranges.reserve(256);
 29    res->debug = debug;
 30
 31    return res;
 32}
 33
// destroy a set of memory ranges that was created with ggml_mem_ranges_init()
void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {
    delete mrs;
}
 37
// forget all recorded ranges (the debug level is left untouched)
void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {
    mrs->ranges.clear();
}
 41
// append a single range to the set - no overlap check is performed here
// always returns true
static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
    mrs->ranges.push_back(mr);

    return true;
}
 47
 48static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
 49    // always use the base tensor
 50    tensor = tensor->view_src ? tensor->view_src : tensor;
 51
 52    GGML_ASSERT(!tensor->view_src);
 53
 54    ggml_mem_range mr;
 55
 56    if (tensor->buffer) {
 57        // when the tensor is allocated, use the actual memory address range in the buffer
 58        //
 59        // take the actual allocated size with ggml_backend_buft_get_alloc_size()
 60        // this can be larger than the tensor size if the buffer type allocates extra memory
 61        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
 62        mr = {
 63            /*.pb =*/ (uint64_t) tensor->buffer,
 64            /*.p0 =*/ (uint64_t) tensor->data,
 65            /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
 66            /*.pt =*/ pt,
 67        };
 68    } else {
 69        // otherwise, the pointer address is used as an unique id of the memory ranges
 70        //   that the tensor will be using when it is allocated
 71        mr = {
 72            /*.pb =*/ (uint64_t) tensor,
 73            /*.p0 =*/ 0,    //
 74            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
 75            /*.pt =*/ pt,
 76        };
 77    };
 78
 79    return mr;
 80}
 81
// the range of a tensor that is being read by an op (MEM_RANGE_TYPE_SRC)
static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
}
 85
// the range of a tensor that is being written by an op (MEM_RANGE_TYPE_DST)
static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
}
 89
 90static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
 91    GGML_ASSERT(tensor);
 92
 93    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
 94
 95    if (mrs->debug > 2) {
 96        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
 97    }
 98
 99    return ggml_mem_ranges_add(mrs, mr);
100}
101
102static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
103    GGML_ASSERT(tensor);
104
105    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
106
107    if (mrs->debug > 2) {
108        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
109    }
110
111    return ggml_mem_ranges_add(mrs, mr);
112}
113
114bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
115    for (int i = 0; i < GGML_MAX_SRC; i++) {
116        if (tensor->src[i]) {
117            ggml_mem_ranges_add_src(mrs, tensor->src[i]);
118        }
119    }
120
121    return ggml_mem_ranges_add_dst(mrs, tensor);
122}
123
124static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
125    for (size_t i = 0; i < mrs->ranges.size(); i++) {
126        const auto & cmp = mrs->ranges[i];
127
128        // two memory ranges cannot intersect if they are in different buffers
129        if (mr.pb != cmp.pb) {
130            continue;
131        }
132
133        // intersecting source ranges are allowed
134        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
135            continue;
136        }
137
138        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
139            if (mrs->debug > 2) {
140                GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
141                        __func__,
142                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
143                        mr.pb, mr.p0, mr.p1,
144                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
145                        cmp.pb, cmp.p0, cmp.p1);
146            }
147
148            return false;
149        }
150    }
151
152    return true;
153}
154
155static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
156    GGML_ASSERT(tensor);
157
158    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
159
160    const bool res = ggml_mem_ranges_check(mrs, mr);
161
162    return res;
163}
164
165static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
166    GGML_ASSERT(tensor);
167
168    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
169
170    const bool res = ggml_mem_ranges_check(mrs, mr);
171
172    return res;
173}
174
175bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
176    for (int i = 0; i < GGML_MAX_SRC; i++) {
177        if (tensor->src[i]) {
178            if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
179                return false;
180            }
181        }
182    }
183
184    return ggml_mem_ranges_check_dst(mrs, tensor);
185}
186
187struct node_info {
188    ggml_tensor * node;
189
190    std::vector<ggml_tensor *> fused;
191
192    ggml_op op() const {
193        return node->op;
194    }
195
196    const ggml_tensor * dst() const {
197        return fused.empty() ? node : fused.back();
198    }
199
200    bool is_empty() const {
201        return ggml_op_is_empty(node->op);
202    }
203
204    void add_fused(ggml_tensor * t) {
205        fused.push_back(t);
206    }
207};
208
// compute an execution order for the (possibly fused) nodes that groups together nodes that can run concurrently
//
// the nodes are visited in their original order. when a node conflicts with the current concurrent set,
// a small window of the following nodes is searched for nodes that can be pulled forward into that set
// before the set is closed and a new one is started
//
// returns the indices into `nodes` in the new execution order (every index appears exactly once)
static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
    // helper to add node src and dst ranges
    const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        // keep track of the sources of the fused nodes as well
        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        // only the final result of the group (the last fused node, or the node itself) is tracked as a dst range
        return ggml_mem_ranges_add_dst(mrs, node.dst());
    };

    // helper to check if a node can run concurrently with the existing set of nodes
    // mirrors h_add: the sources of the node and of its fused nodes are checked as src ranges, node.dst() as a dst range
    const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        return ggml_mem_ranges_check_dst(mrs, node.dst());
    };

    // perform reorders only across these types of ops
    // can be expanded when needed
    const auto & h_safe = [](ggml_op op) {
        switch (op) {
            case GGML_OP_MUL_MAT:
            case GGML_OP_MUL_MAT_ID:
            case GGML_OP_ROPE:
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
            case GGML_OP_GROUP_NORM:
            case GGML_OP_SUM_ROWS:
            case GGML_OP_MUL:
            case GGML_OP_ADD:
            case GGML_OP_DIV:
            case GGML_OP_GLU:
            case GGML_OP_SCALE:
            case GGML_OP_GET_ROWS:
            case GGML_OP_CPY:
            case GGML_OP_SET_ROWS:
                return true;
            default:
                // empty ops are always safe to reorder across
                return ggml_op_is_empty(op);
        }
    };

    const int n = nodes.size();

    // the resulting order: res[k] is the index of the node that is executed k-th
    std::vector<int> res;
    res.reserve(n);

    // used[i] is set when node i has been pulled forward, so that the main loop skips it later
    std::vector<bool> used(n, false);

    // the memory ranges for the set of currently concurrent nodes
    ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);

    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
    ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);

    for (int i0 = 0; i0 < n; i0++) {
        if (used[i0]) {
            continue;
        }

        const auto & node0 = nodes[i0];

        // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e. reset mrs0)
        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
        //
        // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
        if (!node0.is_empty() && !h_check(mrs0, node0)) {
            // this will hold the set of memory ranges from the nodes that haven't been processed yet
            // if a node is not concurrent with this set, we cannot reorder it
            ggml_mem_ranges_reset(mrs1);

            // initialize it with the current node
            h_add(mrs1, node0);

            // size of the look-ahead window: only nodes with index below i0 + N_FORWARD are considered
            constexpr int N_FORWARD = 8;

            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
                if (used[i1]) {
                    continue;
                }

                const auto & node1 = nodes[i1];

                // disallow reordering of certain ops
                // the search stops here, so nothing is ever moved across such an op
                if (!h_safe(node1.op())) {
                    break;
                }

                const bool is_empty = node1.is_empty();

                // to reorder a node and add it to the concurrent set, it has to be:
                //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
                //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
                    h_add(mrs0, node1);
                    res.push_back(i1);

                    // mark as used, so we skip re-processing it later
                    used[i1] = true;
                } else {
                    // expand the set of nodes that haven't been processed yet
                    h_add(mrs1, node1);
                }
            }

            // finalize the concurrent set and begin a new one
            ggml_mem_ranges_reset(mrs0);
        }

        // expand the concurrent set with the current node
        // (at this point mrs0 is either compatible with node0 or has just been reset)
        {
            h_add(mrs0, node0);
            res.push_back(i0);
        }
    }

    ggml_mem_ranges_free(mrs0);
    ggml_mem_ranges_free(mrs1);

    return res;
}
363
364void ggml_graph_optimize(ggml_cgraph * gf) {
365    constexpr int MAX_FUSE = 16;
366
367    const int n = gf->n_nodes;
368
369    enum ggml_op ops[MAX_FUSE];
370
371    std::vector<node_info> nodes;
372    nodes.reserve(gf->n_nodes);
373
374    // fuse nodes:
375    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
376    //   and perform the reorder over the fused nodes. after the reorder is done, we unfuse
377    for (int i = 0; i < n; i++) {
378        node_info node = {
379            /*.node =*/ gf->nodes[i],
380            /*.fused =*/ {},
381        };
382
383        // fuse only ops that start with these operations
384        // can be expanded when needed
385        if (node.op() == GGML_OP_ADD ||
386            node.op() == GGML_OP_NORM ||
387            node.op() == GGML_OP_RMS_NORM) {
388            ops[0] = node.op();
389
390            int f = i + 1;
391            while (f < n && f < i + MAX_FUSE) {
392                // conservatively allow fusing only these ops
393                // can be expanded when needed
394                if (gf->nodes[f]->op != GGML_OP_ADD &&
395                    gf->nodes[f]->op != GGML_OP_MUL &&
396                    gf->nodes[f]->op != GGML_OP_NORM &&
397                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
398                    break;
399                }
400                ops[f - i] = gf->nodes[f]->op;
401                f++;
402            }
403
404            f -= i;
405            for (; f > 1; f--) {
406                if (ggml_can_fuse(gf, i, ops, f)) {
407                    break;
408                }
409            }
410
411            // add the fused tensors into the node info so we can unfuse them later
412            for (int k = 1; k < f; k++) {
413                ++i;
414
415                // the .dst() becomes the last fused tensor
416                node.add_fused(gf->nodes[i]);
417            }
418        }
419
420        nodes.push_back(std::move(node));
421    }
422
423#if 1
424    // reorder to improve concurrency
425    const auto order = ggml_metal_graph_optimize_reorder(nodes);
426#else
427    std::vector<int> order(nodes.size());
428    for (size_t i = 0; i < nodes.size(); i++) {
429        order[i] = i;
430    }
431#endif
432
433    // unfuse
434    {
435        int j = 0;
436        for (const auto i : order) {
437            const auto & node = nodes[i];
438
439            gf->nodes[j++] = node.node;
440
441            for (auto * fused : node.fused) {
442                gf->nodes[j++] = fused;
443            }
444        }
445    }
446}