1#include "arg.h"
  2
  3#include "common.h"
  4#include "gguf.h" // for reading GGUF splits
  5#include "log.h"
  6#include "download.h"
  7
  8#define JSON_ASSERT GGML_ASSERT
  9#include <nlohmann/json.hpp>
 10
 11#include <algorithm>
 12#include <filesystem>
 13#include <fstream>
 14#include <future>
 15#include <map>
 16#include <mutex>
 17#include <regex>
 18#include <string>
 19#include <thread>
 20#include <vector>
 21
 22#if defined(LLAMA_USE_HTTPLIB)
 23#include "http.h"
 24#endif
 25
 26#ifndef __EMSCRIPTEN__
 27#ifdef __linux__
 28#include <linux/limits.h>
 29#elif defined(_WIN32)
 30#   if !defined(PATH_MAX)
 31#   define PATH_MAX MAX_PATH
 32#   endif
 33#elif defined(_AIX)
 34#include <sys/limits.h>
 35#else
 36#include <sys/syslimits.h>
 37#endif
 38#endif
 39
 40#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
 41
 42// isatty
 43#if defined(_WIN32)
 44#include <io.h>
 45#else
 46#include <unistd.h>
 47#endif
 48
 49using json = nlohmann::ordered_json;
 50
 51//
 52// downloader
 53//
 54
 55// validate repo name format: owner/repo
 56static bool validate_repo_name(const std::string & repo) {
 57    static const std::regex repo_regex(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");
 58    return std::regex_match(repo, repo_regex);
 59}
 60
 61static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
 62    // we use "=" to avoid clashing with other component, while still being allowed on windows
 63    std::string fname = "manifest=" + repo + "=" + tag + ".json";
 64    if (!validate_repo_name(repo)) {
 65        throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
 66    }
 67    string_replace_all(fname, "/", "=");
 68    return fs_get_cache_file(fname);
 69}
 70
 71static std::string read_file(const std::string & fname) {
 72    std::ifstream file(fname);
 73    if (!file) {
 74        throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
 75    }
 76    std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
 77    file.close();
 78    return content;
 79}
 80
 81static void write_file(const std::string & fname, const std::string & content) {
 82    const std::string fname_tmp = fname + ".tmp";
 83    std::ofstream     file(fname_tmp);
 84    if (!file) {
 85        throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
 86    }
 87
 88    try {
 89        file << content;
 90        file.close();
 91
 92        // Makes write atomic
 93        if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
 94            LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
 95            // If rename fails, try to delete the temporary file
 96            if (remove(fname_tmp.c_str()) != 0) {
 97                LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
 98            }
 99        }
100    } catch (...) {
101        // If anything fails, try to delete the temporary file
102        if (remove(fname_tmp.c_str()) != 0) {
103            LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
104        }
105
106        throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
107    }
108}
109
110static void write_etag(const std::string & path, const std::string & etag) {
111    const std::string etag_path = path + ".etag";
112    write_file(etag_path, etag);
113    LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
114}
115
116static std::string read_etag(const std::string & path) {
117    std::string none;
118    const std::string etag_path = path + ".etag";
119
120    if (std::filesystem::exists(etag_path)) {
121        std::ifstream etag_in(etag_path);
122        if (!etag_in) {
123            LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
124            return none;
125        }
126        std::string etag;
127        std::getline(etag_in, etag);
128        return etag;
129    }
130
131    // no etag file, but maybe there is an old .json
132    // remove this code later
133    const std::string metadata_path = path + ".json";
134
135    if (std::filesystem::exists(metadata_path)) {
136        std::ifstream metadata_in(metadata_path);
137        try {
138            nlohmann::json metadata_json;
139            metadata_in >> metadata_json;
140            LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
141                    metadata_json.dump().c_str());
142            if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
143                std::string etag = metadata_json.at("etag");
144                write_etag(path, etag);
145                if (!std::filesystem::remove(metadata_path)) {
146                    LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
147                }
148                return etag;
149            }
150        } catch (const nlohmann::json::exception & e) {
151            LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
152        }
153    }
154    return none;
155}
156
// Treat 2xx and 3xx as success; this includes the synthetic 304 that the
// download helpers return when a cached file is reused.
static bool is_http_status_ok(int status) {
    if (status < 200) {
        return false;
    }
    return status < 400;
}
160
161std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
162    auto parts = string_split<std::string>(hf_repo_with_tag, ':');
163    std::string tag = parts.size() > 1 ? parts.back() : "latest";
164    std::string hf_repo = parts[0];
165    if (string_split<std::string>(hf_repo, '/').size() != 2) {
166        throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
167    }
168    return {hf_repo, tag};
169}
170
171#if defined(LLAMA_USE_HTTPLIB)
172
// Console progress bar that supports several concurrent downloads.
//
// Each bar gets its own terminal line on its first update(). All bars share a
// mutex and a registry (bar -> line index), so parallel downloads can redraw
// their own line in place with ANSI escape sequences (save cursor, move up,
// clear line, restore cursor). Nothing is printed when stdout is not a tty.
class ProgressBar {
    static inline std::mutex mutex;
    static inline std::map<const ProgressBar *, int> lines;
    static inline int max_line = 0;

    // set once the bar has been drawn at 100%, so later calls do not register
    // (and print) the bar again
    bool finished = false;

    // forget `line`; once no bar is left the line numbering restarts at 0
    // (caller must hold `mutex`)
    static void cleanup(const ProgressBar * line) {
        lines.erase(line);
        if (lines.empty()) {
            max_line = 0;
        }
    }

    static bool is_output_a_tty() {
#if defined(_WIN32)
        return _isatty(_fileno(stdout));
#else
        return isatty(1);
#endif
    }

public:
    ProgressBar() = default;

    ~ProgressBar() {
        std::lock_guard<std::mutex> lock(mutex);
        cleanup(this);
    }

    // Redraw this bar for `current` of `total` bytes.
    // No-op when stdout is not a tty or `total` is 0 (unknown size).
    void update(size_t current, size_t total) {
        if (!is_output_a_tty()) {
            return;
        }

        if (!total) {
            return;
        }

        std::lock_guard<std::mutex> lock(mutex);

        if (finished) {
            return;
        }

        // a server may deliver more bytes than announced; clamp so `pos` never
        // exceeds `width` (otherwise `width - pos` wraps around and building
        // the padding string throws)
        current = std::min(current, total);

        if (lines.find(this) == lines.end()) {
            lines[this] = max_line++;
            std::cout << "\n";
        }
        int lines_up = max_line - lines[this];

        size_t width = 50;
        size_t pct = (100 * current) / total;
        size_t pos = (width * current) / total;

        std::cout << "\033[s";

        if (lines_up > 0) {
            std::cout << "\033[" << lines_up << "A";
        }
        std::cout << "\033[2K\r["
            << std::string(pos, '=')
            << (pos < width ? ">" : "")
            << std::string(width - pos, ' ')
            << "] " << std::setw(3) << pct << "%  ("
            << current / (1024 * 1024) << " MB / "
            << total / (1024 * 1024) << " MB) "
            << "\033[u";

        std::cout.flush();

        if (current == total) {
            finished = true;
            cleanup(this);
        }
    }

    ProgressBar(const ProgressBar &) = delete;
    ProgressBar & operator=(const ProgressBar &) = delete;
};
246
247static bool common_pull_file(httplib::Client & cli,
248                             const std::string & resolve_path,
249                             const std::string & path_tmp,
250                             bool supports_ranges,
251                             size_t existing_size,
252                             size_t & total_size) {
253    std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
254    if (!ofs.is_open()) {
255        LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
256        return false;
257    }
258
259    httplib::Headers headers;
260    if (supports_ranges && existing_size > 0) {
261        headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
262    }
263
264    const char * func = __func__; // avoid __func__ inside a lambda
265    size_t downloaded = existing_size;
266    size_t progress_step = 0;
267    ProgressBar bar;
268
269    auto res = cli.Get(resolve_path, headers,
270        [&](const httplib::Response &response) {
271            if (existing_size > 0 && response.status != 206) {
272                LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
273                return false;
274            }
275            if (existing_size == 0 && response.status != 200) {
276                LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
277                return false;
278            }
279            if (total_size == 0 && response.has_header("Content-Length")) {
280                try {
281                    size_t content_length = std::stoull(response.get_header_value("Content-Length"));
282                    total_size = existing_size + content_length;
283                } catch (const std::exception &e) {
284                    LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
285                }
286            }
287            return true;
288        },
289        [&](const char *data, size_t len) {
290            ofs.write(data, len);
291            if (!ofs) {
292                LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
293                return false;
294            }
295            downloaded += len;
296            progress_step += len;
297
298            if (progress_step >= total_size / 1000 || downloaded == total_size) {
299                bar.update(downloaded, total_size);
300                progress_step = 0;
301            }
302            return true;
303        },
304        nullptr
305    );
306
307    if (!res) {
308        LOG_ERR("%s: download failed: %s (status: %d)\n",
309                __func__,
310                httplib::to_string(res.error()).c_str(),
311                res ? res->status : -1);
312        return false;
313    }
314
315    return true;
316}
317
318// download one single file from remote URL to local path
319// returns status code or -1 on error
320static int common_download_file_single_online(const std::string        & url,
321                                              const std::string        & path,
322                                              const std::string        & bearer_token,
323                                              const common_header_list & custom_headers) {
324    static const int max_attempts        = 3;
325    static const int retry_delay_seconds = 2;
326
327    auto [cli, parts] = common_http_client(url);
328
329    httplib::Headers headers;
330    for (const auto & h : custom_headers) {
331        headers.emplace(h.first, h.second);
332    }
333    if (headers.find("User-Agent") == headers.end()) {
334        headers.emplace("User-Agent", "llama-cpp/" + build_info);
335    }
336    if (!bearer_token.empty()) {
337        headers.emplace("Authorization", "Bearer " + bearer_token);
338    }
339    cli.set_default_headers(headers);
340
341    const bool file_exists = std::filesystem::exists(path);
342
343    std::string last_etag;
344    if (file_exists) {
345        last_etag = read_etag(path);
346    } else {
347        LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
348    }
349
350    for (int i = 0; i < max_attempts; ++i) {
351        auto head = cli.Head(parts.path);
352        bool head_ok = head && head->status >= 200 && head->status < 300;
353        if (!head_ok) {
354            LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
355            if (file_exists) {
356                LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
357                return 304; // 304 Not Modified - fake cached response
358            }
359            return head->status; // cannot use cached file, return raw status code
360            // TODO: maybe retry only on certain codes
361        }
362
363        std::string etag;
364        if (head_ok && head->has_header("ETag")) {
365            etag = head->get_header_value("ETag");
366        }
367
368        size_t total_size = 0;
369        if (head_ok && head->has_header("Content-Length")) {
370            try {
371                total_size = std::stoull(head->get_header_value("Content-Length"));
372            } catch (const std::exception& e) {
373                LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
374            }
375        }
376
377        bool supports_ranges = false;
378        if (head_ok && head->has_header("Accept-Ranges")) {
379            supports_ranges = head->get_header_value("Accept-Ranges") != "none";
380        }
381
382        bool should_download_from_scratch = false;
383        if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
384            LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
385                    last_etag.c_str(), etag.c_str());
386            should_download_from_scratch = true;
387        }
388
389        if (file_exists) {
390            if (!should_download_from_scratch) {
391                LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
392                return 304; // 304 Not Modified - fake cached response
393            }
394            LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
395            if (remove(path.c_str()) != 0) {
396                LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
397                return -1;
398            }
399        }
400
401        const std::string path_temporary = path + ".downloadInProgress";
402        size_t existing_size = 0;
403
404        if (std::filesystem::exists(path_temporary)) {
405            if (supports_ranges && !should_download_from_scratch) {
406                existing_size = std::filesystem::file_size(path_temporary);
407            } else if (remove(path_temporary.c_str()) != 0) {
408                LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
409                return -1;
410            }
411        }
412
413        // start the download
414        LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
415                __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
416        const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
417        if (!was_pull_successful) {
418            if (i + 1 < max_attempts) {
419                const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
420                LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
421                std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
422            } else {
423                LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
424            }
425            continue;
426        }
427
428        if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
429            LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
430            return -1;
431        }
432        if (!etag.empty()) {
433            write_etag(path, etag);
434        }
435
436        return head->status; // TODO: use actual GET status?
437    }
438
439    return -1; // max attempts reached
440}
441
442std::pair<long, std::vector<char>> common_remote_get_content(const std::string          & url,
443                                                             const common_remote_params & params) {
444    auto [cli, parts] = common_http_client(url);
445
446    httplib::Headers headers;
447    for (const auto & h : params.headers) {
448        headers.emplace(h.first, h.second);
449    }
450    if (headers.find("User-Agent") == headers.end()) {
451        headers.emplace("User-Agent", "llama-cpp/" + build_info);
452    }
453
454    if (params.timeout > 0) {
455        cli.set_read_timeout(params.timeout, 0);
456        cli.set_write_timeout(params.timeout, 0);
457    }
458
459    std::vector<char> buf;
460    auto res = cli.Get(parts.path, headers,
461        [&](const char *data, size_t len) {
462            buf.insert(buf.end(), data, data + len);
463            return params.max_size == 0 ||
464                   buf.size() <= static_cast<size_t>(params.max_size);
465        },
466        nullptr
467    );
468
469    if (!res) {
470        throw std::runtime_error("error: cannot make GET request");
471    }
472
473    return { res->status, std::move(buf) };
474}
475
476int common_download_file_single(const std::string & url,
477                                const std::string & path,
478                                const std::string & bearer_token,
479                                bool offline,
480                                const common_header_list & headers) {
481    if (!offline) {
482        return common_download_file_single_online(url, path, bearer_token, headers);
483    }
484
485    if (!std::filesystem::exists(path)) {
486        LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
487        return -1;
488    }
489
490    LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
491    return 304; // Not Modified - fake cached response
492}
493
494// download multiple files from remote URLs to local paths
495// the input is a vector of pairs <url, path>
496static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
497                                          const std::string & bearer_token,
498                                          bool offline,
499                                          const common_header_list & headers) {
500    // Prepare download in parallel
501    std::vector<std::future<bool>> futures_download;
502    futures_download.reserve(urls.size());
503
504    for (auto const & item : urls) {
505        futures_download.push_back(
506            std::async(
507                std::launch::async,
508                [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
509                    const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
510                    return is_http_status_ok(http_status);
511                },
512                item
513            )
514        );
515    }
516
517    // Wait for all downloads to complete
518    for (auto & f : futures_download) {
519        if (!f.get()) {
520            return false;
521        }
522    }
523
524    return true;
525}
526
527bool common_download_model(const common_params_model & model,
528                           const std::string & bearer_token,
529                           bool offline,
530                           const common_header_list & headers) {
531    // Basic validation of the model.url
532    if (model.url.empty()) {
533        LOG_ERR("%s: invalid model url\n", __func__);
534        return false;
535    }
536
537    const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
538    if (!is_http_status_ok(http_status)) {
539        return false;
540    }
541
542    // check for additional GGUFs split to download
543    int n_split = 0;
544    {
545        struct gguf_init_params gguf_params = {
546            /*.no_alloc = */ true,
547            /*.ctx      = */ NULL,
548        };
549        auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
550        if (!ctx_gguf) {
551            LOG_ERR("\n%s:  failed to load input GGUF from %s\n", __func__, model.path.c_str());
552            return false;
553        }
554
555        auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
556        if (key_n_split >= 0) {
557            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
558        }
559
560        gguf_free(ctx_gguf);
561    }
562
563    if (n_split > 1) {
564        char split_prefix[PATH_MAX] = {0};
565        char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
566
567        // Verify the first split file format
568        // and extract split URL and PATH prefixes
569        {
570            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
571                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
572                return false;
573            }
574
575            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
576                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
577                return false;
578            }
579        }
580
581        std::vector<std::pair<std::string, std::string>> urls;
582        for (int idx = 1; idx < n_split; idx++) {
583            char split_path[PATH_MAX] = {0};
584            llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
585
586            char split_url[LLAMA_MAX_URL_LENGTH] = {0};
587            llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
588
589            if (std::string(split_path) == model.path) {
590                continue; // skip the already downloaded file
591            }
592
593            urls.push_back({split_url, split_path});
594        }
595
596        // Download in parallel
597        common_download_file_multiple(urls, bearer_token, offline, headers);
598    }
599
600    return true;
601}
602
603common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
604                                      const std::string & bearer_token,
605                                      bool offline,
606                                      const common_header_list & custom_headers) {
607    // the returned hf_repo is without tag
608    auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
609
610    std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
611
612    // headers
613    common_header_list headers = custom_headers;
614    headers.push_back({"Accept", "application/json"});
615    if (!bearer_token.empty()) {
616        headers.push_back({"Authorization", "Bearer " + bearer_token});
617    }
618    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
619    // User-Agent header is already set in common_remote_get_content, no need to set it here
620
621    // make the request
622    common_remote_params params;
623    params.headers = headers;
624    long res_code = 0;
625    std::string res_str;
626    bool use_cache = false;
627    std::string cached_response_path = get_manifest_path(hf_repo, tag);
628    if (!offline) {
629        try {
630            auto res = common_remote_get_content(url, params);
631            res_code = res.first;
632            res_str = std::string(res.second.data(), res.second.size());
633        } catch (const std::exception & e) {
634            LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
635        }
636    }
637    if (res_code == 0) {
638        if (std::filesystem::exists(cached_response_path)) {
639            LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
640            res_str = read_file(cached_response_path);
641            res_code = 200;
642            use_cache = true;
643        } else {
644            throw std::runtime_error(
645                offline ? "error: failed to get manifest (offline mode)"
646                : "error: failed to get manifest (check your internet connection)");
647        }
648    }
649    std::string ggufFile;
650    std::string mmprojFile;
651
652    if (res_code == 200 || res_code == 304) {
653        try {
654            auto j = json::parse(res_str);
655
656            if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
657                ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
658            }
659            if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
660                mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
661            }
662        } catch (const std::exception & e) {
663            throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
664        }
665        if (!use_cache) {
666            // if not using cached response, update the cache file
667            write_file(cached_response_path, res_str);
668        }
669    } else if (res_code == 401) {
670        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
671    } else {
672        throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
673    }
674
675    // check response
676    if (ggufFile.empty()) {
677        throw std::runtime_error("error: model does not have ggufFile");
678    }
679
680    return { hf_repo, ggufFile, mmprojFile };
681}
682
683//
684// Docker registry functions
685//
686
687static std::string common_docker_get_token(const std::string & repo) {
688    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
689
690    common_remote_params params;
691    auto                 res = common_remote_get_content(url, params);
692
693    if (res.first != 200) {
694        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
695    }
696
697    std::string            response_str(res.second.begin(), res.second.end());
698    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
699
700    if (!response.contains("token")) {
701        throw std::runtime_error("Docker registry token response missing 'token' field");
702    }
703
704    return response["token"].get<std::string>();
705}
706
707std::string common_docker_resolve_model(const std::string & docker) {
708    // Parse ai/smollm2:135M-Q4_0
709    size_t      colon_pos = docker.find(':');
710    std::string repo, tag;
711    if (colon_pos != std::string::npos) {
712        repo = docker.substr(0, colon_pos);
713        tag  = docker.substr(colon_pos + 1);
714    } else {
715        repo = docker;
716        tag  = "latest";
717    }
718
719    // ai/ is the default
720    size_t      slash_pos = docker.find('/');
721    if (slash_pos == std::string::npos) {
722        repo.insert(0, "ai/");
723    }
724
725    LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
726    try {
727        // --- helper: digest validation ---
728        auto validate_oci_digest = [](const std::string & digest) -> std::string {
729            // Expected: algo:hex ; start with sha256 (64 hex chars)
730            // You can extend this map if supporting other algorithms in future.
731            static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
732            std::smatch m;
733            if (!std::regex_match(digest, m, re)) {
734                throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
735            }
736            // normalize hex to lowercase
737            std::string normalized = digest;
738            std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
739                return std::tolower(c);
740            });
741            return normalized;
742        };
743
744        std::string token = common_docker_get_token(repo);  // Get authentication token
745
746        // Get manifest
747        // TODO: cache the manifest response so that it appears in the model list
748        const std::string    url_prefix = "https://registry-1.docker.io/v2/" + repo;
749        std::string          manifest_url = url_prefix + "/manifests/" + tag;
750        common_remote_params manifest_params;
751        manifest_params.headers.push_back({"Authorization", "Bearer " + token});
752        manifest_params.headers.push_back({"Accept",
753            "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
754        });
755        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
756        if (manifest_res.first != 200) {
757            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
758        }
759
760        std::string            manifest_str(manifest_res.second.begin(), manifest_res.second.end());
761        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
762        std::string            gguf_digest;  // Find the GGUF layer
763        if (manifest.contains("layers")) {
764            for (const auto & layer : manifest["layers"]) {
765                if (layer.contains("mediaType")) {
766                    std::string media_type = layer["mediaType"].get<std::string>();
767                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
768                        media_type.find("gguf") != std::string::npos) {
769                        gguf_digest = layer["digest"].get<std::string>();
770                        break;
771                    }
772                }
773            }
774        }
775
776        if (gguf_digest.empty()) {
777            throw std::runtime_error("No GGUF layer found in Docker manifest");
778        }
779
780        // Validate & normalize digest
781        gguf_digest = validate_oci_digest(gguf_digest);
782        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
783
784        // Prepare local filename
785        std::string model_filename = repo;
786        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
787        model_filename += "_" + tag + ".gguf";
788        std::string local_path = fs_get_cache_file(model_filename);
789
790        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
791        const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
792        if (!is_http_status_ok(http_status)) {
793            throw std::runtime_error("Failed to download Docker Model");
794        }
795
796        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
797        return local_path;
798    } catch (const std::exception & e) {
799        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
800        throw;
801    }
802}
803
804#else
805
806common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
807    throw std::runtime_error("download functionality is not enabled in this build");
808}
809
810bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
811    throw std::runtime_error("download functionality is not enabled in this build");
812}
813
814std::string common_docker_resolve_model(const std::string &) {
815    throw std::runtime_error("download functionality is not enabled in this build");
816}
817
818int common_download_file_single(const std::string &,
819                                const std::string &,
820                                const std::string &,
821                                bool,
822                                const common_header_list &) {
823    throw std::runtime_error("download functionality is not enabled in this build");
824}
825
826#endif // defined(LLAMA_USE_HTTPLIB)
827
828std::vector<common_cached_model_info> common_list_cached_models() {
829    std::vector<common_cached_model_info> models;
830    const std::string cache_dir = fs_get_cache_directory();
831    const std::vector<common_file_info> files = fs_list(cache_dir, false);
832    for (const auto & file : files) {
833        if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
834            common_cached_model_info model_info;
835            model_info.manifest_path = file.path;
836            std::string fname = file.name;
837            string_replace_all(fname, ".json", ""); // remove extension
838            auto parts = string_split<std::string>(fname, '=');
839            if (parts.size() == 4) {
840                // expect format: manifest=<user>=<model>=<tag>=<other>
841                model_info.user  = parts[1];
842                model_info.model = parts[2];
843                model_info.tag   = parts[3];
844            } else {
845                // invalid format
846                continue;
847            }
848            model_info.size = 0; // TODO: get GGUF size, not manifest size
849            models.push_back(model_info);
850        }
851    }
852    return models;
853}