1#include "common.cuh"
  2#include "cross-entropy-loss.cuh"
  3#include "sum.cuh"
  4
  5#include <cmath>
  6#include <cstdint>
  7
  8template <bool use_shared>
  9static __global__ void cross_entropy_loss_f32(
 10        const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
 11    extern __shared__ float tmp[];
 12
 13    logits += int64_t(blockIdx.x)*nclasses;
 14    labels += int64_t(blockIdx.x)*nclasses;
 15
 16    // Find maximum for softmax:
 17    float max_logit = -INFINITY;
 18    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
 19        const float val = logits[i];
 20        max_logit = fmaxf(max_logit, val);
 21
 22        if (use_shared) {
 23            tmp[i] = val;
 24        }
 25    }
 26    max_logit = warp_reduce_max(max_logit);
 27
 28    // Calculate log(softmax(logits)) which is just logits - max:
 29    float sum = 0.0f;
 30    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
 31        const float logit_i = use_shared ? tmp[i] : logits[i];
 32        sum += expf(logit_i - max_logit);
 33    }
 34    sum = warp_reduce_sum(sum);
 35    sum = logf(sum);
 36
 37    // log(exp(logits - max) / sum) = (logits - max) - log(sum)
 38    float loss = 0.0f;
 39    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
 40        const float logit_i = use_shared ? tmp[i] : logits[i];
 41        loss += (logit_i - max_logit - sum) * labels[i];
 42    }
 43    loss = -warp_reduce_sum(loss) / (float)k;
 44
 45    if (threadIdx.x != 0) {
 46        return;
 47    }
 48
 49    dst[blockIdx.x] = loss;
 50}
 51
 52template <bool use_shared>
 53static __global__ void cross_entropy_loss_back_f32(
 54        const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
 55        float * __restrict__ dst, const int nclasses) {
 56    extern __shared__ float tmp[];
 57
 58    logits += int64_t(blockIdx.x)*nclasses;
 59    labels += int64_t(blockIdx.x)*nclasses;
 60    dst    += int64_t(blockIdx.x)*nclasses;
 61
 62    float maxval = -INFINITY;
 63    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
 64        const float val = logits[i];
 65        maxval = fmaxf(maxval, val);
 66
 67        if (use_shared) {
 68            tmp[i] = val;
 69        }
 70    }
 71    maxval = warp_reduce_max(maxval);
 72
 73    float sum = 0.0f;
 74    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
 75        const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
 76        sum += val;
 77
 78        if (use_shared) {
 79            tmp[i] = val;
 80        } else {
 81            dst[i] = val;
 82        }
 83    }
 84    sum = warp_reduce_sum(sum);
 85    const float sm_scale = 1.0f/sum;
 86
 87    const float d_by_nrows = *grad/gridDim.x;
 88    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
 89        const float val = use_shared ? tmp[i] : dst[i];
 90        dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
 91    }
 92}
 93
 94void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 95    const ggml_tensor * src0 = dst->src[0];
 96    const ggml_tensor * src1 = dst->src[1];
 97
 98    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 99    GGML_ASSERT(src1->type == GGML_TYPE_F32);
100    GGML_ASSERT( dst->type == GGML_TYPE_F32);
101
102    GGML_ASSERT(ggml_is_contiguous(src0));
103    GGML_ASSERT(ggml_is_contiguous(src1));
104    GGML_ASSERT(ggml_is_contiguous(dst));
105
106    const int64_t ne00  = src0->ne[0];
107    const int64_t nrows = ggml_nrows(src0);
108
109    const float * src0_d = (const float *) src0->data;
110    const float * src1_d = (const float *) src1->data;
111    float       * dst_d  = (float       *) dst->data;
112
113    ggml_cuda_pool & pool = ctx.pool();
114    cudaStream_t stream = ctx.stream();
115
116    const dim3 blocks_dim(WARP_SIZE, 1, 1);
117    const dim3 blocks_num(nrows, 1, 1);
118    const size_t nbytes_shared = ne00*sizeof(float);
119
120    const int    id    = ggml_cuda_get_device();
121    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
122
123    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
124
125    if (nbytes_shared <= smpbo) {
126        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
127        cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
128    } else {
129        cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
130    }
131    CUDA_CHECK(cudaGetLastError());
132
133    // Combine results from individual blocks:
134    sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
135}
136
// Computes the gradient of the cross-entropy loss with respect to the logits.
//
// Sources: dst->src[0] is the scalar incoming gradient, dst->src[1] the logits and dst->src[2]
// the labels. dst receives (softmax(logits) - labels) * grad / nrows and must have the same
// shape as the logits. All tensors must be contiguous F32.
//
// One block of WARP_SIZE threads is launched per row of logits. If a full row fits into the
// per-block opt-in shared memory limit of the current device (smpbo), the kernel variant that
// keeps its intermediate values in shared memory is used; otherwise the variant that uses dst
// as scratch space is used.
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0];
    const ggml_tensor * src0f = dst->src[1]; // logits
    const ggml_tensor * src1f = dst->src[2]; // labels

    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
    GGML_ASSERT(src1f->type == GGML_TYPE_F32);
    GGML_ASSERT( grad->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_scalar(grad));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(src1f));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
    GGML_ASSERT(ggml_are_same_shape(src0f, dst));

    const int64_t ne00  = src0f->ne[0];
    const int64_t nrows = ggml_nrows(src0f);

    const float * grad_d  = (const float *) grad->data;
    const float * src0f_d = (const float *) src0f->data;
    const float * src1f_d = (const float *) src1f->data;
    float       * dst_d   = (float       *) dst->data;

    cudaStream_t stream = ctx.stream();

    const dim3 blocks_dim(WARP_SIZE, 1, 1);
    const dim3 blocks_num(nrows, 1, 1); // the kernel uses gridDim.x as the row count
    const size_t nbytes_shared = ne00*sizeof(float); // one row of logits

    const int    id    = ggml_cuda_get_device();
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

    if (nbytes_shared <= smpbo) {
        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
        cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
    } else {
        cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
    }
    // Report launch configuration errors here, as the forward pass does:
    CUDA_CHECK(cudaGetLastError());
}