#include "ggml-metal-common.h"

#include "ggml-impl.h"
#include "ggml-backend-impl.h"

#include <vector>

// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
struct ggml_mem_range {
    uint64_t pb; // buffer id

    uint64_t p0; // begin
    uint64_t p1; // end

    ggml_mem_range_type pt;
};

struct ggml_mem_ranges {
    std::vector<ggml_mem_range> ranges;

    int debug = 0;
};
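
// illustrative sketch of how these ranges are meant to be used (op0 and op1 are hypothetical graph nodes):
//
//   ggml_mem_ranges_t mrs = ggml_mem_ranges_init(0);
//
//   ggml_mem_ranges_add(mrs, op0);       // record the src/dst ranges touched by op0
//
//   if (ggml_mem_ranges_check(mrs, op1)) {
//       // op1 does not write any range read or written by op0 and does not read any range written by op0,
//       // so the two ops can be encoded concurrently
//       ggml_mem_ranges_add(mrs, op1);
//   } else {
//       // conflict -> put a barrier and start a new concurrent set
//       ggml_mem_ranges_reset(mrs);
//   }
//
//   ggml_mem_ranges_free(mrs);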

ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
    auto * res = new ggml_mem_ranges;

    res->ranges.reserve(256);
    res->debug = debug;

    return res;
}

void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {
    delete mrs;
}

void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {
    mrs->ranges.clear();
}

static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
    mrs->ranges.push_back(mr);

    return true;
}

static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
    // always use the base tensor
    tensor = tensor->view_src ? tensor->view_src : tensor;

    GGML_ASSERT(!tensor->view_src);

    ggml_mem_range mr;

    if (tensor->buffer) {
        // when the tensor is allocated, use the actual memory address range in the buffer
        //
        // take the actual allocated size with ggml_backend_buft_get_alloc_size()
        // this can be larger than the tensor size if the buffer type allocates extra memory
        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
        mr = {
            /*.pb =*/ (uint64_t) tensor->buffer,
            /*.p0 =*/ (uint64_t) tensor->data,
            /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
            /*.pt =*/ pt,
        };
    } else {
        // otherwise, the pointer address is used as a unique id of the memory ranges
        // that the tensor will be using when it is allocated
        mr = {
            /*.pb =*/ (uint64_t) tensor,
            /*.p0 =*/ 0,
            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
            /*.pt =*/ pt,
        };
    }

    return mr;
}
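
// for example (hypothetical tensors): if `t` is a view into an allocated tensor `base`, the range above is
// derived from `base`: pb = base->buffer, [base->data, base->data + alloc_size), so overlapping views of the
// same base tensor are detected as touching the same memory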

static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
}

static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
}

static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);

    if (mrs->debug > 2) {
        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
    }

    return ggml_mem_ranges_add(mrs, mr);
}

static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);

    if (mrs->debug > 2) {
        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
    }

    return ggml_mem_ranges_add(mrs, mr);
}

bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (tensor->src[i]) {
            ggml_mem_ranges_add_src(mrs, tensor->src[i]);
        }
    }

    return ggml_mem_ranges_add_dst(mrs, tensor);
}

static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
    for (size_t i = 0; i < mrs->ranges.size(); i++) {
        const auto & cmp = mrs->ranges[i];

        // two memory ranges cannot intersect if they are in different buffers
        if (mr.pb != cmp.pb) {
            continue;
        }

        // intersecting source ranges are allowed
        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
            continue;
        }

        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
            if (mrs->debug > 2) {
                GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
                        __func__,
                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                        mr.pb, mr.p0, mr.p1,
                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                        cmp.pb, cmp.p0, cmp.p1);
            }

            return false;
        }
    }

    return true;
}
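
// a small worked example of the rules above (addresses are made up):
//
//   existing: src [0, 64)  in buffer A
//   new:      src [32, 96) in buffer A   -> ok    (two reads may overlap)
//   new:      dst [32, 96) in buffer A   -> fails (a write overlaps a pending read)
//   new:      dst [32, 96) in buffer B   -> ok    (different buffer, cannot alias)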

static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);

    const bool res = ggml_mem_ranges_check(mrs, mr);

    return res;
}

static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);

    const bool res = ggml_mem_ranges_check(mrs, mr);

    return res;
}

bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (tensor->src[i]) {
            if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
                return false;
            }
        }
    }

    return ggml_mem_ranges_check_dst(mrs, tensor);
}

struct node_info {
    ggml_tensor * node;

    std::vector<ggml_tensor *> fused;

    ggml_op op() const {
        return node->op;
    }

    const ggml_tensor * dst() const {
        return fused.empty() ? node : fused.back();
    }

    bool is_empty() const {
        return ggml_op_is_empty(node->op);
    }

    void add_fused(ggml_tensor * t) {
        fused.push_back(t);
    }
};
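
// for example (hypothetical): if an RMS_NORM node has a following MUL fused into it, .node is the RMS_NORM,
// .fused holds the MUL, and dst() is the MUL's output - the only destination range that the reorder logic
// tracks for the whole fused sequence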

static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
    // helper to add node src and dst ranges
    const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        // keep track of the sources of the fused nodes as well
        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        return ggml_mem_ranges_add_dst(mrs, node.dst());
    };

    // helper to check if a node can run concurrently with the existing set of nodes
    const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        return ggml_mem_ranges_check_dst(mrs, node.dst());
    };

    // perform reorders only across these types of ops
    // can be expanded when needed
    const auto & h_safe = [](ggml_op op) {
        switch (op) {
            case GGML_OP_MUL_MAT:
            case GGML_OP_MUL_MAT_ID:
            case GGML_OP_ROPE:
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
            case GGML_OP_GROUP_NORM:
            case GGML_OP_SUM_ROWS:
            case GGML_OP_MUL:
            case GGML_OP_ADD:
            case GGML_OP_DIV:
            case GGML_OP_GLU:
            case GGML_OP_SCALE:
            case GGML_OP_GET_ROWS:
            case GGML_OP_CPY:
            case GGML_OP_SET_ROWS:
                return true;
            default:
                return ggml_op_is_empty(op);
        }
    };
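
    // illustrative example of the loop below (op names are hypothetical):
    //
    //   order in the graph: A, B, C   where B depends on A, but C is independent of both
    //   resulting order:    A, C, B   C is pulled forward into the concurrent set of A, so that A and C
    //                                 can be encoded without a barrier between them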

    const int n = nodes.size();

    std::vector<int> res;
    res.reserve(n);

    std::vector<bool> used(n, false);

    // the memory ranges for the set of currently concurrent nodes
    ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);

    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
    ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);

    for (int i0 = 0; i0 < n; i0++) {
        if (used[i0]) {
            continue;
        }

        const auto & node0 = nodes[i0];

        // if the node is not concurrent with the existing concurrent set, we have to "put a barrier" (i.e. reset mrs0)
        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
        //
        // note: we can always add empty nodes to the concurrent set as they neither read nor write anything
        if (!node0.is_empty() && !h_check(mrs0, node0)) {
            // this will hold the set of memory ranges from the nodes that haven't been processed yet
            // if a node is not concurrent with this set, we cannot reorder it
            ggml_mem_ranges_reset(mrs1);

            // initialize it with the current node
            h_add(mrs1, node0);

            // how many nodes to look forward for a concurrent node
            constexpr int N_FORWARD = 8;

            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
                if (used[i1]) {
                    continue;
                }

                const auto & node1 = nodes[i1];

                // disallow reordering of certain ops
                if (!h_safe(node1.op())) {
                    break;
                }

                const bool is_empty = node1.is_empty();

                // to reorder a node and add it to the concurrent set, it has to be:
                //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
                //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
                    h_add(mrs0, node1);
                    res.push_back(i1);

                    // mark as used, so we skip re-processing it later
                    used[i1] = true;
                } else {
                    // expand the set of nodes that haven't been processed yet
                    h_add(mrs1, node1);
                }
            }

            // finalize the concurrent set and begin a new one
            ggml_mem_ranges_reset(mrs0);
        }

        // expand the concurrent set with the current node
        {
            h_add(mrs0, node0);
            res.push_back(i0);
        }
    }

    ggml_mem_ranges_free(mrs0);
    ggml_mem_ranges_free(mrs1);

    return res;
}
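
// fuse, reorder and unfuse the graph nodes in-place to improve the number of consecutive nodes that can be
// encoded concurrently (i.e. without memory barriers between them)
//
// minimal usage sketch (assuming `gf` is an already-built compute graph, e.g. in the Metal backend):
//
//   ggml_graph_optimize(gf);
//   // gf->nodes now holds a permutation of the original nodes; fused sequences remain adjacent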
void ggml_graph_optimize(ggml_cgraph * gf) {
    constexpr int MAX_FUSE = 16;

    const int n = gf->n_nodes;

    enum ggml_op ops[MAX_FUSE];

    std::vector<node_info> nodes;
    nodes.reserve(gf->n_nodes);

    // fuse nodes:
    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
    // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
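    //
    // for example (assuming the usual norm + scale pattern): an RMS_NORM node immediately followed by a MUL
    // that ggml_can_fuse() accepts is packed into a single node_info - the RMS_NORM stays in .node, the MUL
    // goes into .fused, so the pair is reordered as one unit and written back in order during the unfuse step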
    for (int i = 0; i < n; i++) {
        node_info node = {
            /*.node  =*/ gf->nodes[i],
            /*.fused =*/ {},
        };

        // fuse only ops that start with these operations
        // can be expanded when needed
        if (node.op() == GGML_OP_ADD ||
            node.op() == GGML_OP_NORM ||
            node.op() == GGML_OP_RMS_NORM) {
            ops[0] = node.op();

            int f = i + 1;
            while (f < n && f < i + MAX_FUSE) {
                // conservatively allow fusing only these ops
                // can be expanded when needed
                if (gf->nodes[f]->op != GGML_OP_ADD &&
                    gf->nodes[f]->op != GGML_OP_MUL &&
                    gf->nodes[f]->op != GGML_OP_NORM &&
                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
                    break;
                }
                ops[f - i] = gf->nodes[f]->op;
                f++;
            }

            f -= i;
            for (; f > 1; f--) {
                if (ggml_can_fuse(gf, i, ops, f)) {
                    break;
                }
            }

            // add the fused tensors into the node info so we can unfuse them later
            for (int k = 1; k < f; k++) {
                ++i;

                // the .dst() becomes the last fused tensor
                node.add_fused(gf->nodes[i]);
            }
        }

        nodes.push_back(std::move(node));
    }

#if 1
    // reorder to improve concurrency
    const auto order = ggml_metal_graph_optimize_reorder(nodes);
#else
    std::vector<int> order(nodes.size());
    for (size_t i = 0; i < nodes.size(); i++) {
        order[i] = i;
    }
#endif

    // unfuse
    {
        int j = 0;
        for (const auto i : order) {
            const auto & node = nodes[i];

            gf->nodes[j++] = node.node;

            for (auto * fused : node.fused) {
                gf->nodes[j++] = fused;
            }
        }
    }
}