#include "ggml-opt.h"

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-impl.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cinttypes>
#include <cstring> // for memcpy/strcpy
#include <map>
#include <random>
#include <vector>

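// A dataset is stored as 2D data/labels tensors in host memory. Shuffling and batch
// extraction happen at the granularity of shards of ndata_shard consecutive
// datapoints; nbs_data/nbs_labels are the sizes of a single shard in bytes.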
struct ggml_opt_dataset {
    struct ggml_context   * ctx    = nullptr;
    ggml_backend_buffer_t   buf    = nullptr;
    struct ggml_tensor    * data   = nullptr;
    struct ggml_tensor    * labels = nullptr;

    int64_t ndata       = -1;
    int64_t ndata_shard = -1;
    size_t  nbs_data    = -1;
    size_t  nbs_labels  = -1;

    std::vector<int64_t> permutation;
};

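// Optimization state. Depending on the build type, up to three graphs are kept:
// gf (forward only), gb_grad (forward + backward), and gb_opt (forward + backward +
// optimizer step). With static graphs these are built once in ggml_opt_init;
// otherwise they are rebuilt for every evaluation via ggml_opt_prepare_alloc.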
struct ggml_opt_context {
    ggml_backend_sched_t       backend_sched        = nullptr;
    ggml_cgraph              * allocated_graph      = nullptr;
    ggml_cgraph              * allocated_graph_copy = nullptr;
    struct ggml_context      * ctx_static           = nullptr;
    struct ggml_context      * ctx_cpu              = nullptr;
    struct ggml_context      * ctx_compute          = nullptr;
    struct ggml_context      * ctx_copy             = nullptr;
    ggml_backend_buffer_t      buf_static           = nullptr;
    ggml_backend_buffer_t      buf_cpu              = nullptr;
    std::mt19937               rng;
    enum ggml_opt_loss_type    loss_type;
    enum ggml_opt_build_type   build_type;
    enum ggml_opt_build_type   build_type_alloc;

    struct ggml_tensor * inputs  = nullptr;
    struct ggml_tensor * outputs = nullptr;
    struct ggml_tensor * labels  = nullptr;

    struct ggml_tensor * loss     = nullptr;
    struct ggml_tensor * pred     = nullptr;
    struct ggml_tensor * ncorrect = nullptr;

    struct ggml_cgraph * gf      = nullptr;
    struct ggml_cgraph * gb_grad = nullptr;
    struct ggml_cgraph * gb_opt  = nullptr;
    bool static_graphs           = false;
    bool eval_ready              = false;
    std::vector<struct ggml_tensor *> grad_accs;
    std::vector<struct ggml_tensor *> grad_m;
    std::vector<struct ggml_tensor *> grad_v;

    int64_t iter               = 1;
    int32_t opt_period         = 1;
    int32_t opt_i              = 0;
    bool    loss_per_datapoint = false;

    ggml_opt_get_optimizer_params get_opt_pars    = nullptr;
    void *                        get_opt_pars_ud = nullptr;
    struct ggml_tensor *          opt_step_params = nullptr; // Stores output of get_opt_pars.

    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
};

struct ggml_opt_result {
    int64_t              ndata    = 0;
    std::vector<float>   loss;
    std::vector<int32_t> pred;
    int64_t              ncorrect = 0;

    int64_t opt_period         = -1;
    bool    loss_per_datapoint = false;
};

// ====== Dataset ======

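// Note: ndata is expected to be an integer multiple of ndata_shard; any trailing
// remainder would otherwise be silently excluded from the shard permutation.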
ggml_opt_dataset_t ggml_opt_dataset_init(
        enum ggml_type type_data,
        enum ggml_type type_label,
        int64_t        ne_datapoint,
        int64_t        ne_label,
        int64_t        ndata,
        int64_t        ndata_shard) {
    GGML_ASSERT(ne_datapoint >  0);
    GGML_ASSERT(ne_label     >= 0);
    GGML_ASSERT(ndata        >  0);
    GGML_ASSERT(ndata_shard  >  0);

    ggml_opt_dataset_t result = new ggml_opt_dataset;
    result->ndata       = ndata;
    result->ndata_shard = ndata_shard;

    {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 2*ggml_tensor_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        result->ctx = ggml_init(params);
    }

    result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
    result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;

    if (ne_label > 0) {
        result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
        result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
    } else {
        result->labels = nullptr;
        result->nbs_labels = 0;
    }

    result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type());

    const int64_t nshards = ndata/ndata_shard;
    result->permutation.resize(nshards);
    for (int64_t i = 0; i < nshards; ++i) {
        result->permutation[i] = i;
    }
    return result;
}

void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
    ggml_backend_buffer_free(dataset->buf);
    ggml_free(dataset->ctx);
    delete dataset;
}

int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
    return dataset->ndata;
}

struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
    return dataset->data;
}

struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) {
    return dataset->labels;
}

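// Shuffle the dataset (at shard granularity) using the RNG of opt_ctx.
// idata < 0 shuffles the whole dataset, otherwise only the first idata datapoints,
// which allows keeping a validation split at the end of the dataset untouched.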
void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) {
    GGML_ASSERT(idata <= dataset->ndata);

    if (idata < 0) {
        std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng);
        return;
    }

    GGML_ASSERT(idata % dataset->ndata_shard == 0);
    const int64_t ishard_max = idata / dataset->ndata_shard;
    std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng);
}

void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) {
    GGML_ASSERT(   data_batch && ggml_is_contiguous(data_batch));
    GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
    GGML_ASSERT(                   data_batch->type == dataset->data->type);
    GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);

    const size_t nb_data_batch = ggml_nbytes(data_batch);
    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;

    if (labels_batch) {
        const size_t nb_labels_batch = ggml_nbytes(labels_batch);
        GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
    }

    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));

    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];

        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
        ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);

        if (!labels_batch) {
            continue;
        }

        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
        ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
    }
}

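// Same as ggml_opt_dataset_get_batch but copies the batch into plain host buffers
// with memcpy instead of into backend tensors.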
void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);

    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;

    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));

    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];

        const char * ptr_data       = (const char *) dataset->data->data + ishard      *dataset->nbs_data;
        char       * ptr_data_batch = (char       *) data_batch          + ishard_batch*dataset->nbs_data;
        memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);

        if (!labels_batch) {
            continue;
        }

        const char * ptr_labels       = (const char *) dataset->labels->data + ishard      *dataset->nbs_labels;
        char       * ptr_labels_batch = (char       *) labels_batch          + ishard_batch*dataset->nbs_labels;
        memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
    }
}

// ====== Model / Context ======

struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
    GGML_UNUSED(userdata);

    ggml_opt_optimizer_params result;

    result.adamw.alpha = 0.001f;
    result.adamw.beta1 = 0.9f;
    result.adamw.beta2 = 0.999f;
    result.adamw.eps   = 1e-8f;
    result.adamw.wd    = 0.0f;

    result.sgd.alpha   = 1e-3f;
    result.sgd.wd      = 0.0f;

    return result;
}

struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
    return *((struct ggml_opt_optimizer_params *) userdata);
}

struct ggml_opt_params ggml_opt_default_params(
        ggml_backend_sched_t      backend_sched,
        enum ggml_opt_loss_type   loss_type) {
    return {
        /*backend_sched   =*/ backend_sched,
        /*ctx_compute     =*/ nullptr,
        /*inputs          =*/ nullptr,
        /*logits          =*/ nullptr,
        /*loss_type       =*/ loss_type,
        /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,
        /*opt_period      =*/ 1,
        /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
        /*get_opt_pars_ud =*/ nullptr,
        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
    };
}

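// Deep-copy a tensor (including view_src and src tensors) into ctx, reusing already
// copied tensors via tensor_map. Data pointers and buffers are shared with the
// originals, so only the graph structure is duplicated, not the tensor contents.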
static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_map, ggml_context * ctx, ggml_tensor * tensor) {
    if (!tensor) {
        return nullptr;
    }

    if (tensor_map.find(tensor) != tensor_map.end()) {
        return tensor_map[tensor];
    }

    ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor);
    tensor_map[tensor] = new_tensor;

    new_tensor->op = tensor->op;
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        new_tensor->nb[i] = tensor->nb[i];
    }
    new_tensor->flags = tensor->flags;
    memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));
    strcpy(new_tensor->name, tensor->name);
    new_tensor->data = tensor->data;
    new_tensor->buffer = tensor->buffer;
    new_tensor->extra = tensor->extra;
    new_tensor->view_offs = tensor->view_offs;
    new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src);
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]);
    }

    return new_tensor;
}

static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
    std::map<ggml_tensor *, ggml_tensor *> tensor_map;

    ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);

    for (int i = 0; i < src->n_leafs; i++) {
        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i]));
    }
    GGML_ASSERT(dst->n_leafs == src->n_leafs);
    for (int i = 0; i < src->n_nodes; i++) {
        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i]));
    }
    GGML_ASSERT(dst->n_nodes == src->n_nodes);
    for (int i = 0; i < src->n_nodes; ++i) {
        const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
        const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);

        GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
        GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
        GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
        GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));

        dst->grads[igrad_dst]     = src->grads[igrad_src];
        dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
    }

    return dst;
}

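// Build the loss on top of the forward graph and, depending on build_type_alloc,
// the gradients (gb_grad) and the optimizer step (gb_opt). Long-lived tensors
// (gradients, momenta, labels, loss, pred, ncorrect) are placed in ctx_static,
// the optimizer parameter tensor in ctx_cpu.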
static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
    GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
    GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");

    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;

    const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
        !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);

    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;

    ggml_set_input(opt_ctx->inputs);
    ggml_set_output(opt_ctx->outputs);

    int n_param = 0;
    for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
        const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
            n_param++;
        }
        GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
    }

    if (!opt_ctx->ctx_static) {
        // The static context is used for:
        //   - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
        //   - optimizer momenta (2 tensors per param)
        //   - labels (if using static graphs)
        //   - loss (if using static graphs, up to 5 tensors)
        //   - pred (if using static graphs)
        //   - ncorrect (if using static graphs, 2 tensors).
        constexpr size_t n_loss = 1;
        const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
        const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
        const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        opt_ctx->ctx_static = ggml_init(params);
    }
    GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);

    {
        // The cpu context is allocated statically if using static graphs, dynamically otherwise.
        // It is used for:
        //   - optimizer parameters (1 shared for all optimizer invocations)
        const size_t size_meta = 1 * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        ggml_free(opt_ctx->ctx_cpu);
        opt_ctx->ctx_cpu = ggml_init(params);

        ggml_backend_buffer_free(opt_ctx->buf_cpu);
        opt_ctx->buf_cpu = nullptr;
    }

    struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;

    switch (opt_ctx->loss_type) {
        case GGML_OPT_LOSS_TYPE_MEAN: {
            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
            ggml_set_name(opt_ctx->loss, "loss_sum");
            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
            ggml_set_name(opt_ctx->loss, "loss_mean");
            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_SUM: {
            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
            ggml_set_name(opt_ctx->loss, "loss_sum");
            opt_ctx->loss_per_datapoint = false;
            break;
        }
        case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
            ggml_set_input(opt_ctx->labels);
            ggml_set_name(opt_ctx->labels, "labels");
            opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
            ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
            if (opt_ctx->opt_period > 1) {
                opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
                ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
            }
            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
            ggml_set_input(opt_ctx->labels);
            ggml_set_name(opt_ctx->labels, "labels");
            opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
            ggml_set_name(opt_ctx->loss, "loss_error");
            opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
            ggml_set_name(opt_ctx->loss, "loss_squared_error");
            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
            ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
            ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
            opt_ctx->loss_per_datapoint = true;
            break;
        }
    }
    ggml_set_output(opt_ctx->loss);
    ggml_set_loss(opt_ctx->loss);
    ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);

    if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
        opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
        ggml_set_name(opt_ctx->pred, "pred");
        ggml_set_output(opt_ctx->pred);
        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);

        opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
        ggml_set_name(opt_ctx->ncorrect, "ncorrect");
        ggml_set_output(opt_ctx->ncorrect);
        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
    }

    if (opt_ctx->buf_static) {
        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
            return;
        }
    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        return;
    }

    if (opt_ctx->grad_accs.empty()) {
        GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);

        const int n_nodes = opt_ctx->gf->n_nodes;
        opt_ctx->grad_accs.resize(n_nodes);
        for (int i = 0; i < n_nodes; ++i) {
            ggml_tensor * node = opt_ctx->gf->nodes[i];
            if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
                opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
            } else {
                opt_ctx->grad_accs[i] = nullptr;
            }
        }

        if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
            opt_ctx->grad_m.resize(n_nodes);
            opt_ctx->grad_v.resize(n_nodes);
            for (int i = 0; i < n_nodes; ++i) {
                ggml_tensor * node = opt_ctx->gf->nodes[i];
                if (node->flags & GGML_TENSOR_FLAG_PARAM) {
                    opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
                    opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
                } else {
                    opt_ctx->grad_m[i] = nullptr;
                    opt_ctx->grad_v[i] = nullptr;
                }
            }
        }
    }

    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
    opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
    ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());

    if (opt_ctx->buf_static) {
        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
            return;
        }
    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        ggml_graph_reset(opt_ctx->gb_grad);
    }

    GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);

    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
    opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);

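    // The optimizer parameters are passed through a small host tensor:
    // 7 floats for AdamW (alpha, beta1, beta2, eps, wd, beta1h, beta2h),
    // 2 floats for SGD (alpha, wd). See ggml_opt_eval for where they are set.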
    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
    ggml_set_input(adamw_params);
    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
    ggml_format_name(adamw_params, "%s_params", optimizer_name);
    for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
        struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
        struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);

        if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
            struct ggml_tensor * m = nullptr;
            struct ggml_tensor * v = nullptr;
            if (need_momenta) {
                m = opt_ctx->grad_m[i];
                v = opt_ctx->grad_v[i];
                ggml_format_name(m, "AdamW m for %s", node->name);
                ggml_format_name(v, "AdamW v for %s", node->name);
            }
            struct ggml_tensor * opt_step;
            switch (optimizer) {
                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
                    break;
                case GGML_OPT_OPTIMIZER_TYPE_SGD:
                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
                    break;
                default:
                    GGML_ABORT("fatal error");
            }
            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
            ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
        }
    }

    if (!opt_ctx->buf_static) {
        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        ggml_graph_reset(opt_ctx->gb_opt);
    }

    opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
}

ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
    ggml_opt_context_t result = new struct ggml_opt_context;
    result->backend_sched    = params.backend_sched;
    result->ctx_compute      = params.ctx_compute;
    result->loss_type        = params.loss_type;
    result->build_type       = params.build_type;
    result->build_type_alloc = params.build_type;
    result->inputs           = params.inputs;
    result->outputs          = params.outputs;
    result->opt_period       = params.opt_period;
    result->get_opt_pars     = params.get_opt_pars;
    result->get_opt_pars_ud  = params.get_opt_pars_ud;
    result->optimizer        = params.optimizer;

    GGML_ASSERT(result->opt_period >= 1);

    result->static_graphs = result->ctx_compute;

    if (!result->static_graphs) {
        GGML_ASSERT(!result->inputs);
        GGML_ASSERT(!result->outputs);
        return result;
    }

    GGML_ASSERT(result->inputs);
    GGML_ASSERT(result->outputs);

    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
    ggml_build_forward_expand(result->gf, result->outputs);

    ggml_opt_build(result);

    return result;
}

void ggml_opt_free(ggml_opt_context_t opt_ctx) {
    if (opt_ctx == nullptr) {
        return;
    }
    ggml_backend_buffer_free(opt_ctx->buf_static);
    ggml_backend_buffer_free(opt_ctx->buf_cpu);
    ggml_free(opt_ctx->ctx_static);
    ggml_free(opt_ctx->ctx_cpu);
    delete opt_ctx;
}

void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
    if (optimizer) {
        ggml_graph_reset(opt_ctx->gb_opt);
        opt_ctx->iter = 1;
    } else {
        ggml_graph_reset(opt_ctx->gb_grad);
    }
}

bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
    return opt_ctx->static_graphs;
}

struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
    return opt_ctx->inputs;
}

struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) {
    return opt_ctx->outputs;
}

struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) {
    return opt_ctx->labels;
}

struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) {
    return opt_ctx->loss;
}

struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) {
    return opt_ctx->pred;
}

struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) {
    return opt_ctx->ncorrect;
}

struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) {
    return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node);
}

// ====== Optimization Result ======

ggml_opt_result_t ggml_opt_result_init() {
    return new ggml_opt_result;
}

void ggml_opt_result_free(ggml_opt_result_t result) {
    delete result;
}

void ggml_opt_result_reset(ggml_opt_result_t result) {
    result->ndata = 0;
    result->loss.clear();
    result->pred.clear();
    result->ncorrect = 0;
}

void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) {
    *ndata = result->ndata;
}

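// Return the accumulated loss (per-datapoint mean or total sum, depending on the
// loss type) and optionally an uncertainty estimated from the scatter of the
// per-batch losses. unc may be nullptr if the uncertainty is not needed.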
void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) {
    const int64_t nbatches = result->loss.size(); // Number of physical batches.

    if (nbatches == 0) {
        *loss = 0.0;
        if (unc) {
            *unc = NAN;
        }
        return;
    }

    double sum         = 0.0;
    double sum_squared = 0.0;

    for (const float & loss : result->loss) {
        // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
        const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss;
        sum         += loss_scaled;
        sum_squared += loss_scaled*loss_scaled;
    }

    const double mean = sum/nbatches;
    *loss = result->loss_per_datapoint ? mean : sum;

    if (!unc) {
        return;
    }

    if (nbatches < 2) {
        *unc = NAN;
        return;
    }

    const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. missing the factor nbatches/(nbatches-1)
    *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
}

void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) {
    for (size_t i = 0; i < result->pred.size(); ++i) {
        pred[i] = result->pred[i];
    }
}

void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) {
    *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN;

    if (!unc) {
        return;
    }

    *unc = result->ncorrect >= 0 && result->ndata >= 2 ?
        sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN;
}

// ====== Computation ======

void ggml_opt_prepare_alloc(
        ggml_opt_context_t    opt_ctx,
        struct ggml_context * ctx_compute,
        struct ggml_cgraph  * gf,
        struct ggml_tensor  * inputs,
        struct ggml_tensor  * outputs) {
    GGML_ASSERT(!opt_ctx->static_graphs);
    opt_ctx->ctx_compute = ctx_compute;
    opt_ctx->gf          = gf;
    opt_ctx->inputs      = inputs;
    opt_ctx->outputs     = outputs;
}

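// Choose the graph to evaluate (forward, gradient accumulation, or optimizer step)
// and allocate it on the scheduler. For static graphs the chosen graph is duplicated
// into ctx_copy so that resetting the scheduler does not invalidate the original.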
void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
    GGML_ASSERT(!opt_ctx->eval_ready);
    if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
        ggml_graph_reset(opt_ctx->gb_grad);
    }
    if (backward) {
        const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
        opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
    } else {
        opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
    }

    if (!opt_ctx->static_graphs) {
        ggml_opt_build(opt_ctx);
    }

    struct ggml_cgraph * graph = nullptr;
    switch (opt_ctx->build_type) {
        case GGML_OPT_BUILD_TYPE_FORWARD: {
            graph = opt_ctx->gf;
        } break;
        case GGML_OPT_BUILD_TYPE_GRAD: {
            graph = opt_ctx->gb_grad;
        } break;
        case GGML_OPT_BUILD_TYPE_OPT: {
            graph = opt_ctx->gb_opt;
        } break;
    }
    GGML_ASSERT(graph);

    if (opt_ctx->allocated_graph == graph) {
        opt_ctx->eval_ready = true;
        return;
    }

    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph

    if (opt_ctx->static_graphs) {
        ggml_init_params params = {
            /*.mem_size   =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        ggml_free(opt_ctx->ctx_copy);
        opt_ctx->ctx_copy = ggml_init(params);

        opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
    } else {
        opt_ctx->allocated_graph_copy = graph;
    }

    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
    opt_ctx->allocated_graph = graph;

    opt_ctx->eval_ready = true;
}

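// Evaluate the graph prepared by ggml_opt_alloc. If this evaluation includes the
// optimizer step, fetch the optimizer parameters via the user callback, validate
// them, and write them into the host tensor read by the opt step ops.
// Loss/pred/ncorrect statistics are accumulated into result if it is non-null.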
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
    GGML_ASSERT(opt_ctx->eval_ready);
    if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);

        switch (opt_ctx->optimizer) {
            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);

                // bias-correction factors 1/(1 - beta^iter) for the current iteration
                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));

                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
                adamw_par_data[0] = opt_pars.adamw.alpha;
                adamw_par_data[1] = opt_pars.adamw.beta1;
                adamw_par_data[2] = opt_pars.adamw.beta2;
                adamw_par_data[3] = opt_pars.adamw.eps;
                adamw_par_data[4] = opt_pars.adamw.wd;
                adamw_par_data[5] = beta1h;
                adamw_par_data[6] = beta2h;
            } break;
            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
                sgd[0] = opt_pars.sgd.alpha;
                sgd[1] = opt_pars.sgd.wd;
            } break;
            default:
                GGML_ABORT("fatal error");
        }
    }

    ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
    opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
    opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;

    if (!opt_ctx->static_graphs) {
        opt_ctx->gf                   = nullptr;
        opt_ctx->gb_grad              = nullptr;
        opt_ctx->gb_opt               = nullptr;
        opt_ctx->allocated_graph      = nullptr;
        opt_ctx->allocated_graph_copy = nullptr;
    }

    opt_ctx->eval_ready = false;

    if (!result) {
        return;
    }

    if (result->ndata == 0) {
        result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
        result->opt_period         = opt_ctx->opt_period;
    } else {
        GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
        GGML_ASSERT(result->opt_period         == opt_ctx->opt_period);
    }

    const int64_t ndata = opt_ctx->outputs->ne[1];
    GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
    result->ndata += ndata;

    GGML_ASSERT(ggml_is_scalar(opt_ctx->loss));
    GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32);
    float loss;
    ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
    result->loss.push_back(loss);

    if (opt_ctx->pred) {
        GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
        std::vector<int32_t> pred(ndata);
        ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
        result->pred.insert(result->pred.end(), pred.begin(), pred.end());
    }

    if (!opt_ctx->ncorrect || result->ncorrect < 0) {
        result->ncorrect = -1;
        return;
    }

    GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect));
    GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64);
    int64_t ncorrect;
    ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect));
    result->ncorrect += ncorrect;
}

// ====== High-Level Functions ======

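// Run one epoch over the dataset: the first idata_split datapoints are used for
// training (with backward pass), the rest for evaluation (forward only).
// A negative idata_split means the whole dataset is used for training.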
void ggml_opt_epoch(
        ggml_opt_context_t      opt_ctx,
        ggml_opt_dataset_t      dataset,
        ggml_opt_result_t       result_train,
        ggml_opt_result_t       result_eval,
        int64_t                 idata_split,
        ggml_opt_epoch_callback callback_train,
        ggml_opt_epoch_callback callback_eval) {
    GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
    struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
    struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
    struct ggml_tensor * data   = ggml_opt_dataset_data(dataset);
    GGML_ASSERT(data->ne[0] == inputs->ne[0]);

    const int64_t ndata       =   data->ne[1];
    const int64_t ndata_batch = inputs->ne[1];

    GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
    const int64_t nbatches = ndata/ndata_batch;

    idata_split = idata_split < 0 ? ndata : idata_split;
    GGML_ASSERT(idata_split % ndata_batch == 0);
    const int64_t ibatch_split = idata_split / ndata_batch;

    int64_t ibatch = 0;
    int64_t t_loop_start = ggml_time_us();
    for (; ibatch < ibatch_split; ++ibatch) {
        ggml_opt_alloc(opt_ctx, /*backward =*/ true);
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
        ggml_opt_eval(opt_ctx, result_train);
        if (callback_train) {
            callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
        }
    }
    t_loop_start = ggml_time_us();
    for (; ibatch < nbatches; ++ibatch) {
        ggml_opt_alloc(opt_ctx, /*backward =*/ false);
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
        ggml_opt_eval(opt_ctx, result_eval);
        if (callback_eval) {
            callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
        }
    }
}

void ggml_opt_epoch_callback_progress_bar(
        bool               train,
        ggml_opt_context_t opt_ctx,
        ggml_opt_dataset_t dataset,
        ggml_opt_result_t  result,
        int64_t            ibatch,
        int64_t            ibatch_max,
        int64_t            t_start_us) {
    fprintf(stderr, "%s[", train ? "train: " : "val:   ");

    // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
    constexpr int64_t bar_length = 8;
    const int64_t ibatch8 = 8 * ibatch;
    for (int64_t j = 0; j < bar_length; ++j) {
        if        (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
            fprintf(stderr, "\u2588"); // full block
        } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
            fprintf(stderr, "\u2589"); // 7/8 filled
        } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
            fprintf(stderr, "\u258A"); // 6/8 filled
        } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
            fprintf(stderr, "\u258B"); // 5/8 filled
        } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
            fprintf(stderr, "\u258C"); // 4/8 filled
        } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
            fprintf(stderr, "\u258D"); // 3/8 filled
        } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
            fprintf(stderr, "\u258E"); // 2/8 filled
        } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
            fprintf(stderr, "\u258F"); // 1/8 filled
        } else {
            fprintf(stderr, " ");
        }
    }

    const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1];
    const int64_t idata      = ibatch*batch_size;
    const int64_t idata_max  = ibatch_max*batch_size;

    double loss;
    double loss_unc;
    ggml_opt_result_loss(result, &loss, &loss_unc);

    double accuracy;
    double accuracy_unc;
    ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);

    const int64_t t_ibatch_us = ggml_time_us() - t_start_us;
    int64_t t_ibatch_s = t_ibatch_us / 1000000;
    const int64_t t_ibatch_h = t_ibatch_s / 3600;
    t_ibatch_s -= t_ibatch_h * 3600;
    const int64_t t_ibatch_m = t_ibatch_s / 60;
    t_ibatch_s -= t_ibatch_m * 60;

    const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
    int64_t t_eta_s = t_eta_us / 1000000;
    const int64_t t_eta_h = t_eta_s / 3600;
    t_eta_s -= t_eta_h * 3600;
    const int64_t t_eta_m = t_eta_s / 60;
    t_eta_s -= t_eta_m * 60;

    fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
            idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
            t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
    if (ibatch == ibatch_max) {
        fprintf(stderr, "\n");
    }
    fflush(stderr);

    GGML_UNUSED(dataset);
}

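// High-level training loop. nbatch_logical controls how often an optimizer step is
// taken, the physical batch size inputs->ne[1] how much data each graph evaluation
// processes; their ratio becomes opt_period. The last val_split fraction of the
// logical batches is held out for validation.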
void ggml_opt_fit(
        ggml_backend_sched_t            backend_sched,
        ggml_context                  * ctx_compute,
        ggml_tensor                   * inputs,
        ggml_tensor                   * outputs,
        ggml_opt_dataset_t              dataset,
        enum ggml_opt_loss_type         loss_type,
        enum ggml_opt_optimizer_type    optimizer,
        ggml_opt_get_optimizer_params   get_opt_pars,
        int64_t                         nepoch,
        int64_t                         nbatch_logical,
        float                           val_split,
        bool                            silent) {
    ggml_time_init();
    const int64_t t_start_us = ggml_time_us();

    const int64_t ndata           = ggml_opt_dataset_data(dataset)->ne[1];
    const int64_t nbatch_physical = inputs->ne[1];
    GGML_ASSERT(ndata          % nbatch_logical  == 0);
    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);

    const int64_t opt_period       = nbatch_logical / nbatch_physical;
    const int64_t nbatches_logical = ndata / nbatch_logical;

    GGML_ASSERT(val_split >= 0.0f);
    GGML_ASSERT(val_split <  1.0f);
    const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)
    const int64_t idata_split  = ibatch_split * nbatch_physical;

    int64_t epoch = 1;

    ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
    params.ctx_compute     = ctx_compute;
    params.inputs          = inputs;
    params.outputs         = outputs;
    params.opt_period      = opt_period;
    params.get_opt_pars    = get_opt_pars;
    params.get_opt_pars_ud = &epoch;
    params.optimizer       = optimizer;
    ggml_opt_context_t opt_ctx = ggml_opt_init(params);

    // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
    if (nbatch_logical < ndata) {
        ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
    }

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_val   = ggml_opt_result_init();

    ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar;

    for (; epoch <= nepoch; ++epoch) {
        if (nbatch_logical < idata_split) {
            ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
        }

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_val);

        if (!silent) {
            fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
        }
        ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
        if (!silent) {
            fprintf(stderr, "\n");
        }
    }

    if (!silent) {
        int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000;
        const int64_t t_total_h = t_total_s / 3600;
        t_total_s -= t_total_h * 3600;
        const int64_t t_total_m = t_total_s / 60;
        t_total_s -= t_total_m * 60;
        fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
    }

    ggml_opt_free(opt_ctx);
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_val);
}

enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
    return c->optimizer;
}

GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
    switch (o) {
        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
            return "adamw";
        case GGML_OPT_OPTIMIZER_TYPE_SGD:
            return "sgd";
        default:
            return "undefined";
    }
}