1#include "arg.h"
2
3#include "common.h"
4#include "gguf.h" // for reading GGUF splits
5#include "log.h"
6#include "download.h"
7
8#define JSON_ASSERT GGML_ASSERT
9#include <nlohmann/json.hpp>
10
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <map>
#include <mutex>
#include <regex>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
21
22#if defined(LLAMA_USE_HTTPLIB)
23#include "http.h"
24#endif
25
26#ifndef __EMSCRIPTEN__
27#ifdef __linux__
28#include <linux/limits.h>
29#elif defined(_WIN32)
30# if !defined(PATH_MAX)
31# define PATH_MAX MAX_PATH
32# endif
33#elif defined(_AIX)
34#include <sys/limits.h>
35#else
36#include <sys/syslimits.h>
37#endif
38#endif
39
40#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
41
42// isatty
43#if defined(_WIN32)
44#include <io.h>
45#else
46#include <unistd.h>
47#endif
48
49using json = nlohmann::ordered_json;
50
51//
52// downloader
53//
54
// check that a repo id has the form "owner/repo": exactly one '/' separating
// two non-empty segments made of [A-Za-z0-9_.-]
static bool validate_repo_name(const std::string & repo) {
    const auto is_allowed = [](char c) {
        return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
               (c >= '0' && c <= '9') || c == '_' || c == '.' || c == '-';
    };
    const size_t sep = repo.find('/');
    if (sep == std::string::npos || sep == 0 || sep + 1 == repo.size()) {
        return false; // no separator, or an empty owner/repo segment
    }
    if (repo.find('/', sep + 1) != std::string::npos) {
        return false; // more than one separator
    }
    return std::all_of(repo.begin(), repo.end(),
                       [&](char c) { return c == '/' || is_allowed(c); });
}
60
61static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
62 // we use "=" to avoid clashing with other component, while still being allowed on windows
63 std::string fname = "manifest=" + repo + "=" + tag + ".json";
64 if (!validate_repo_name(repo)) {
65 throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
66 }
67 string_replace_all(fname, "/", "=");
68 return fs_get_cache_file(fname);
69}
70
71static std::string read_file(const std::string & fname) {
72 std::ifstream file(fname);
73 if (!file) {
74 throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
75 }
76 std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
77 file.close();
78 return content;
79}
80
81static void write_file(const std::string & fname, const std::string & content) {
82 const std::string fname_tmp = fname + ".tmp";
83 std::ofstream file(fname_tmp);
84 if (!file) {
85 throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
86 }
87
88 try {
89 file << content;
90 file.close();
91
92 // Makes write atomic
93 if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
94 LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
95 // If rename fails, try to delete the temporary file
96 if (remove(fname_tmp.c_str()) != 0) {
97 LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
98 }
99 }
100 } catch (...) {
101 // If anything fails, try to delete the temporary file
102 if (remove(fname_tmp.c_str()) != 0) {
103 LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
104 }
105
106 throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
107 }
108}
109
110static void write_etag(const std::string & path, const std::string & etag) {
111 const std::string etag_path = path + ".etag";
112 write_file(etag_path, etag);
113 LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
114}
115
// return the etag stored for a previously downloaded file, or "" if unknown
// also migrates the legacy "<path>.json" metadata file (which embedded the
// etag in a JSON object) to the newer single-line "<path>.etag" format
static std::string read_etag(const std::string & path) {
    std::string none;
    const std::string etag_path = path + ".etag";

    if (std::filesystem::exists(etag_path)) {
        std::ifstream etag_in(etag_path);
        if (!etag_in) {
            LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
            return none;
        }
        // the file holds a single line: the etag value itself
        std::string etag;
        std::getline(etag_in, etag);
        return etag;
    }

    // no etag file, but maybe there is an old .json
    // remove this code later
    const std::string metadata_path = path + ".json";

    if (std::filesystem::exists(metadata_path)) {
        std::ifstream metadata_in(metadata_path);
        try {
            nlohmann::json metadata_json;
            metadata_in >> metadata_json;
            LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
                    metadata_json.dump().c_str());
            if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
                std::string etag = metadata_json.at("etag");
                // migrate: store in the new format, then drop the legacy file
                write_etag(path, etag);
                if (!std::filesystem::remove(metadata_path)) {
                    LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
                }
                return etag;
            }
        } catch (const nlohmann::json::exception & e) {
            // malformed legacy metadata is not fatal - behave as if no etag exists
            LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
        }
    }
    return none;
}
156
// treat 2xx and 3xx as success - this includes the fake 304 "Not Modified"
// used throughout this file to signal a cache hit
static bool is_http_status_ok(int status) {
    return 200 <= status && status < 400;
}
160
161std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
162 auto parts = string_split<std::string>(hf_repo_with_tag, ':');
163 std::string tag = parts.size() > 1 ? parts.back() : "latest";
164 std::string hf_repo = parts[0];
165 if (string_split<std::string>(hf_repo, '/').size() != 2) {
166 throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
167 }
168 return {hf_repo, tag};
169}
170
171#if defined(LLAMA_USE_HTTPLIB)
172
// terminal progress bar for downloads; several bars can be active at once
// (parallel split downloads), each rendered on its own terminal line
// all shared state is guarded by a single static mutex
class ProgressBar {
    static inline std::mutex mutex;                          // guards lines, max_line and terminal writes
    static inline std::map<const ProgressBar *, int> lines;  // bar instance -> assigned line index
    static inline int max_line = 0;                          // number of lines handed out so far

    // forget a finished bar; reset the counter once no bars remain
    // NOTE: caller must hold `mutex`
    static void cleanup(const ProgressBar * line) {
        lines.erase(line);
        if (lines.empty()) {
            max_line = 0;
        }
    }

    // only draw when stdout is an interactive terminal
    static bool is_output_a_tty() {
#if defined(_WIN32)
        return _isatty(_fileno(stdout));
#else
        return isatty(1);
#endif
    }

public:
    ProgressBar() = default;

    ~ProgressBar() {
        std::lock_guard<std::mutex> lock(mutex);
        cleanup(this);
    }

    // redraw this bar's line showing `current` of `total` bytes completed
    // no-op when stdout is not a tty or the total is unknown (0)
    void update(size_t current, size_t total) {
        if (!is_output_a_tty()) {
            return;
        }

        if (!total) {
            return;
        }

        std::lock_guard<std::mutex> lock(mutex);

        // first update for this bar: allocate a fresh terminal line
        if (lines.find(this) == lines.end()) {
            lines[this] = max_line++;
            std::cout << "\n";
        }
        // distance (in lines) from the cursor up to this bar's line
        int lines_up = max_line - lines[this];

        size_t width = 50;
        size_t pct = (100 * current) / total;
        size_t pos = (width * current) / total;

        // save cursor, move up to our line, clear it, draw, restore cursor
        std::cout << "\033[s";

        if (lines_up > 0) {
            std::cout << "\033[" << lines_up << "A";
        }
        std::cout << "\033[2K\r["
                  << std::string(pos, '=')
                  << (pos < width ? ">" : "")
                  << std::string(width - pos, ' ')
                  << "] " << std::setw(3) << pct << "% ("
                  << current / (1024 * 1024) << " MB / "
                  << total / (1024 * 1024) << " MB) "
                  << "\033[u";

        std::cout.flush();

        // done: release the line now; the destructor's cleanup is then a no-op
        if (current == total) {
            cleanup(this);
        }
    }

    ProgressBar(const ProgressBar &) = delete;
    ProgressBar & operator=(const ProgressBar &) = delete;
};
246
247static bool common_pull_file(httplib::Client & cli,
248 const std::string & resolve_path,
249 const std::string & path_tmp,
250 bool supports_ranges,
251 size_t existing_size,
252 size_t & total_size) {
253 std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
254 if (!ofs.is_open()) {
255 LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
256 return false;
257 }
258
259 httplib::Headers headers;
260 if (supports_ranges && existing_size > 0) {
261 headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
262 }
263
264 const char * func = __func__; // avoid __func__ inside a lambda
265 size_t downloaded = existing_size;
266 size_t progress_step = 0;
267 ProgressBar bar;
268
269 auto res = cli.Get(resolve_path, headers,
270 [&](const httplib::Response &response) {
271 if (existing_size > 0 && response.status != 206) {
272 LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
273 return false;
274 }
275 if (existing_size == 0 && response.status != 200) {
276 LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
277 return false;
278 }
279 if (total_size == 0 && response.has_header("Content-Length")) {
280 try {
281 size_t content_length = std::stoull(response.get_header_value("Content-Length"));
282 total_size = existing_size + content_length;
283 } catch (const std::exception &e) {
284 LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
285 }
286 }
287 return true;
288 },
289 [&](const char *data, size_t len) {
290 ofs.write(data, len);
291 if (!ofs) {
292 LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
293 return false;
294 }
295 downloaded += len;
296 progress_step += len;
297
298 if (progress_step >= total_size / 1000 || downloaded == total_size) {
299 bar.update(downloaded, total_size);
300 progress_step = 0;
301 }
302 return true;
303 },
304 nullptr
305 );
306
307 if (!res) {
308 LOG_ERR("%s: download failed: %s (status: %d)\n",
309 __func__,
310 httplib::to_string(res.error()).c_str(),
311 res ? res->status : -1);
312 return false;
313 }
314
315 return true;
316}
317
318// download one single file from remote URL to local path
319// returns status code or -1 on error
320static int common_download_file_single_online(const std::string & url,
321 const std::string & path,
322 const std::string & bearer_token,
323 const common_header_list & custom_headers) {
324 static const int max_attempts = 3;
325 static const int retry_delay_seconds = 2;
326
327 auto [cli, parts] = common_http_client(url);
328
329 httplib::Headers headers;
330 for (const auto & h : custom_headers) {
331 headers.emplace(h.first, h.second);
332 }
333 if (headers.find("User-Agent") == headers.end()) {
334 headers.emplace("User-Agent", "llama-cpp/" + build_info);
335 }
336 if (!bearer_token.empty()) {
337 headers.emplace("Authorization", "Bearer " + bearer_token);
338 }
339 cli.set_default_headers(headers);
340
341 const bool file_exists = std::filesystem::exists(path);
342
343 std::string last_etag;
344 if (file_exists) {
345 last_etag = read_etag(path);
346 } else {
347 LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
348 }
349
350 for (int i = 0; i < max_attempts; ++i) {
351 auto head = cli.Head(parts.path);
352 bool head_ok = head && head->status >= 200 && head->status < 300;
353 if (!head_ok) {
354 LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
355 if (file_exists) {
356 LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
357 return 304; // 304 Not Modified - fake cached response
358 }
359 return head->status; // cannot use cached file, return raw status code
360 // TODO: maybe retry only on certain codes
361 }
362
363 std::string etag;
364 if (head_ok && head->has_header("ETag")) {
365 etag = head->get_header_value("ETag");
366 }
367
368 size_t total_size = 0;
369 if (head_ok && head->has_header("Content-Length")) {
370 try {
371 total_size = std::stoull(head->get_header_value("Content-Length"));
372 } catch (const std::exception& e) {
373 LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
374 }
375 }
376
377 bool supports_ranges = false;
378 if (head_ok && head->has_header("Accept-Ranges")) {
379 supports_ranges = head->get_header_value("Accept-Ranges") != "none";
380 }
381
382 bool should_download_from_scratch = false;
383 if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
384 LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
385 last_etag.c_str(), etag.c_str());
386 should_download_from_scratch = true;
387 }
388
389 if (file_exists) {
390 if (!should_download_from_scratch) {
391 LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
392 return 304; // 304 Not Modified - fake cached response
393 }
394 LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
395 if (remove(path.c_str()) != 0) {
396 LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
397 return -1;
398 }
399 }
400
401 const std::string path_temporary = path + ".downloadInProgress";
402 size_t existing_size = 0;
403
404 if (std::filesystem::exists(path_temporary)) {
405 if (supports_ranges && !should_download_from_scratch) {
406 existing_size = std::filesystem::file_size(path_temporary);
407 } else if (remove(path_temporary.c_str()) != 0) {
408 LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
409 return -1;
410 }
411 }
412
413 // start the download
414 LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
415 __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
416 const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
417 if (!was_pull_successful) {
418 if (i + 1 < max_attempts) {
419 const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
420 LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
421 std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
422 } else {
423 LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
424 }
425 continue;
426 }
427
428 if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
429 LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
430 return -1;
431 }
432 if (!etag.empty()) {
433 write_etag(path, etag);
434 }
435
436 return head->status; // TODO: use actual GET status?
437 }
438
439 return -1; // max attempts reached
440}
441
442std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
443 const common_remote_params & params) {
444 auto [cli, parts] = common_http_client(url);
445
446 httplib::Headers headers;
447 for (const auto & h : params.headers) {
448 headers.emplace(h.first, h.second);
449 }
450 if (headers.find("User-Agent") == headers.end()) {
451 headers.emplace("User-Agent", "llama-cpp/" + build_info);
452 }
453
454 if (params.timeout > 0) {
455 cli.set_read_timeout(params.timeout, 0);
456 cli.set_write_timeout(params.timeout, 0);
457 }
458
459 std::vector<char> buf;
460 auto res = cli.Get(parts.path, headers,
461 [&](const char *data, size_t len) {
462 buf.insert(buf.end(), data, data + len);
463 return params.max_size == 0 ||
464 buf.size() <= static_cast<size_t>(params.max_size);
465 },
466 nullptr
467 );
468
469 if (!res) {
470 throw std::runtime_error("error: cannot make GET request");
471 }
472
473 return { res->status, std::move(buf) };
474}
475
476int common_download_file_single(const std::string & url,
477 const std::string & path,
478 const std::string & bearer_token,
479 bool offline,
480 const common_header_list & headers) {
481 if (!offline) {
482 return common_download_file_single_online(url, path, bearer_token, headers);
483 }
484
485 if (!std::filesystem::exists(path)) {
486 LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
487 return -1;
488 }
489
490 LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
491 return 304; // Not Modified - fake cached response
492}
493
494// download multiple files from remote URLs to local paths
495// the input is a vector of pairs <url, path>
496static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
497 const std::string & bearer_token,
498 bool offline,
499 const common_header_list & headers) {
500 // Prepare download in parallel
501 std::vector<std::future<bool>> futures_download;
502 futures_download.reserve(urls.size());
503
504 for (auto const & item : urls) {
505 futures_download.push_back(
506 std::async(
507 std::launch::async,
508 [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
509 const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
510 return is_http_status_ok(http_status);
511 },
512 item
513 )
514 );
515 }
516
517 // Wait for all downloads to complete
518 for (auto & f : futures_download) {
519 if (!f.get()) {
520 return false;
521 }
522 }
523
524 return true;
525}
526
527bool common_download_model(const common_params_model & model,
528 const std::string & bearer_token,
529 bool offline,
530 const common_header_list & headers) {
531 // Basic validation of the model.url
532 if (model.url.empty()) {
533 LOG_ERR("%s: invalid model url\n", __func__);
534 return false;
535 }
536
537 const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
538 if (!is_http_status_ok(http_status)) {
539 return false;
540 }
541
542 // check for additional GGUFs split to download
543 int n_split = 0;
544 {
545 struct gguf_init_params gguf_params = {
546 /*.no_alloc = */ true,
547 /*.ctx = */ NULL,
548 };
549 auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
550 if (!ctx_gguf) {
551 LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
552 return false;
553 }
554
555 auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
556 if (key_n_split >= 0) {
557 n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
558 }
559
560 gguf_free(ctx_gguf);
561 }
562
563 if (n_split > 1) {
564 char split_prefix[PATH_MAX] = {0};
565 char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
566
567 // Verify the first split file format
568 // and extract split URL and PATH prefixes
569 {
570 if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
571 LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
572 return false;
573 }
574
575 if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
576 LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
577 return false;
578 }
579 }
580
581 std::vector<std::pair<std::string, std::string>> urls;
582 for (int idx = 1; idx < n_split; idx++) {
583 char split_path[PATH_MAX] = {0};
584 llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
585
586 char split_url[LLAMA_MAX_URL_LENGTH] = {0};
587 llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
588
589 if (std::string(split_path) == model.path) {
590 continue; // skip the already downloaded file
591 }
592
593 urls.push_back({split_url, split_path});
594 }
595
596 // Download in parallel
597 common_download_file_multiple(urls, bearer_token, offline, headers);
598 }
599
600 return true;
601}
602
// resolve a HF repo reference ("user/model[:tag]") to the GGUF (and optional
// mmproj) filenames via the HF manifest API
// falls back to a locally cached manifest when the network request fails or
// offline mode is requested; throws on any unrecoverable failure
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
                                      const std::string & bearer_token,
                                      bool offline,
                                      const common_header_list & custom_headers) {
    // the returned hf_repo is without tag
    auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);

    std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;

    // headers
    common_header_list headers = custom_headers;
    headers.push_back({"Accept", "application/json"});
    if (!bearer_token.empty()) {
        headers.push_back({"Authorization", "Bearer " + bearer_token});
    }
    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
    // User-Agent header is already set in common_remote_get_content, no need to set it here

    // make the request
    common_remote_params params;
    params.headers = headers;
    long res_code = 0;
    std::string res_str;
    bool use_cache = false;
    std::string cached_response_path = get_manifest_path(hf_repo, tag);
    if (!offline) {
        try {
            auto res = common_remote_get_content(url, params);
            res_code = res.first;
            res_str = std::string(res.second.data(), res.second.size());
        } catch (const std::exception & e) {
            // network failure is not yet fatal - the cache may still satisfy the request
            LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
        }
    }
    // res_code == 0 means no request was made (offline) or it failed entirely
    if (res_code == 0) {
        if (std::filesystem::exists(cached_response_path)) {
            LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
            res_str = read_file(cached_response_path);
            res_code = 200;
            use_cache = true;
        } else {
            throw std::runtime_error(
                offline ? "error: failed to get manifest (offline mode)"
                        : "error: failed to get manifest (check your internet connection)");
        }
    }
    std::string ggufFile;
    std::string mmprojFile;

    if (res_code == 200 || res_code == 304) {
        try {
            auto j = json::parse(res_str);

            if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
                ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
            }
            if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
                mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
            }
        } catch (const std::exception & e) {
            throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
        }
        if (!use_cache) {
            // if not using cached response, update the cache file
            write_file(cached_response_path, res_str);
        }
    } else if (res_code == 401) {
        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
    } else {
        throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
    }

    // check response
    if (ggufFile.empty()) {
        throw std::runtime_error("error: model does not have ggufFile");
    }

    return { hf_repo, ggufFile, mmprojFile };
}
682
683//
684// Docker registry functions
685//
686
687static std::string common_docker_get_token(const std::string & repo) {
688 std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
689
690 common_remote_params params;
691 auto res = common_remote_get_content(url, params);
692
693 if (res.first != 200) {
694 throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
695 }
696
697 std::string response_str(res.second.begin(), res.second.end());
698 nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
699
700 if (!response.contains("token")) {
701 throw std::runtime_error("Docker registry token response missing 'token' field");
702 }
703
704 return response["token"].get<std::string>();
705}
706
// resolve a Docker Hub model reference (e.g. "ai/smollm2:135M-Q4_0") to a
// local GGUF file, downloading the blob from the registry when needed
// returns the local cache path of the model; throws on failure
std::string common_docker_resolve_model(const std::string & docker) {
    // Parse ai/smollm2:135M-Q4_0
    size_t colon_pos = docker.find(':');
    std::string repo, tag;
    if (colon_pos != std::string::npos) {
        repo = docker.substr(0, colon_pos);
        tag = docker.substr(colon_pos + 1);
    } else {
        repo = docker;
        tag = "latest";
    }

    // ai/ is the default
    size_t slash_pos = docker.find('/');
    if (slash_pos == std::string::npos) {
        repo.insert(0, "ai/");
    }

    LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
    try {
        // --- helper: digest validation ---
        auto validate_oci_digest = [](const std::string & digest) -> std::string {
            // Expected: algo:hex ; start with sha256 (64 hex chars)
            // You can extend this map if supporting other algorithms in future.
            static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
            std::smatch m;
            if (!std::regex_match(digest, m, re)) {
                throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
            }
            // normalize hex to lowercase
            std::string normalized = digest;
            std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
                return std::tolower(c);
            });
            return normalized;
        };

        std::string token = common_docker_get_token(repo);  // Get authentication token

        // Get manifest
        // TODO: cache the manifest response so that it appears in the model list
        const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
        std::string manifest_url = url_prefix + "/manifests/" + tag;
        common_remote_params manifest_params;
        manifest_params.headers.push_back({"Authorization", "Bearer " + token});
        manifest_params.headers.push_back({"Accept",
            "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
        });
        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
        if (manifest_res.first != 200) {
            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
        }

        std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
        std::string gguf_digest;  // Find the GGUF layer
        if (manifest.contains("layers")) {
            for (const auto & layer : manifest["layers"]) {
                if (layer.contains("mediaType")) {
                    std::string media_type = layer["mediaType"].get<std::string>();
                    // match either the exact Docker AI GGUF media type or any
                    // media type mentioning "gguf"
                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
                        media_type.find("gguf") != std::string::npos) {
                        gguf_digest = layer["digest"].get<std::string>();
                        break;
                    }
                }
            }
        }

        if (gguf_digest.empty()) {
            throw std::runtime_error("No GGUF layer found in Docker manifest");
        }

        // Validate & normalize digest
        gguf_digest = validate_oci_digest(gguf_digest);
        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());

        // Prepare local filename
        std::string model_filename = repo;
        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
        model_filename += "_" + tag + ".gguf";
        std::string local_path = fs_get_cache_file(model_filename);

        // fetch the blob by digest; the registry token doubles as the bearer token
        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
        const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
        if (!is_http_status_ok(http_status)) {
            throw std::runtime_error("Failed to download Docker Model");
        }

        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
        return local_path;
    } catch (const std::exception & e) {
        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
        throw;
    }
}
803
804#else
805
// stub implementations used when the project is built without HTTP support
// (LLAMA_USE_HTTPLIB not defined): every download entry point fails loudly

common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
    throw std::runtime_error("download functionality is not enabled in this build");
}

bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
    throw std::runtime_error("download functionality is not enabled in this build");
}

std::string common_docker_resolve_model(const std::string &) {
    throw std::runtime_error("download functionality is not enabled in this build");
}

int common_download_file_single(const std::string &,
                                const std::string &,
                                const std::string &,
                                bool,
                                const common_header_list &) {
    throw std::runtime_error("download functionality is not enabled in this build");
}
825
826#endif // defined(LLAMA_USE_HTTPLIB)
827
828std::vector<common_cached_model_info> common_list_cached_models() {
829 std::vector<common_cached_model_info> models;
830 const std::string cache_dir = fs_get_cache_directory();
831 const std::vector<common_file_info> files = fs_list(cache_dir, false);
832 for (const auto & file : files) {
833 if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
834 common_cached_model_info model_info;
835 model_info.manifest_path = file.path;
836 std::string fname = file.name;
837 string_replace_all(fname, ".json", ""); // remove extension
838 auto parts = string_split<std::string>(fname, '=');
839 if (parts.size() == 4) {
840 // expect format: manifest=<user>=<model>=<tag>=<other>
841 model_info.user = parts[1];
842 model_info.model = parts[2];
843 model_info.tag = parts[3];
844 } else {
845 // invalid format
846 continue;
847 }
848 model_info.size = 0; // TODO: get GGUF size, not manifest size
849 models.push_back(model_info);
850 }
851 }
852 return models;
853}