#include "common.cuh"
#include "wkv.cuh"

#include <climits>
  3
  4template <int block_size>
  5static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
  6    const int tid = threadIdx.x;
  7    const int bid = blockIdx.x;
  8
  9    const int head_size = block_size;
 10    const int batch_i = bid / H;
 11    const int head_i = bid % H;
 12    const int state_size = C * head_size;
 13    const int n_seq_tokens = T / B;
 14
 15    float state[head_size];
 16    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
 17
 18    #pragma unroll
 19    for (int i = 0; i < head_size; i++) {
 20        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
 21    }
 22
 23    __syncthreads();
 24    _tf[tid] = tf[head_i * head_size + tid];
 25    __syncthreads();
 26
 27    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
 28        __syncthreads();
 29        _k[tid] = k[t];
 30        _r[tid] = r[t];
 31        _td[tid] = td[t];
 32        __syncthreads();
 33
 34        const float _v = v[t];
 35        float y = 0;
 36        for (int j = 0; j < head_size; j += 4) {
 37            const float4& k = (float4&)(_k[j]);
 38            const float4& r = (float4&)(_r[j]);
 39            const float4& tf = (float4&)(_tf[j]);
 40            const float4& td = (float4&)(_td[j]);
 41            float4& s = (float4&)(state[j]);
 42            float4 kv;
 43
 44            kv.x = k.x * _v;
 45            kv.y = k.y * _v;
 46            kv.z = k.z * _v;
 47            kv.w = k.w * _v;
 48
 49            y += r.x * (tf.x * kv.x + s.x);
 50            y += r.y * (tf.y * kv.y + s.y);
 51            y += r.z * (tf.z * kv.z + s.z);
 52            y += r.w * (tf.w * kv.w + s.w);
 53
 54            s.x = s.x * td.x + kv.x;
 55            s.y = s.y * td.y + kv.y;
 56            s.z = s.z * td.z + kv.z;
 57            s.w = s.w * td.w + kv.w;
 58        }
 59        dst[t] = y;
 60    }
 61
 62    #pragma unroll
 63    for (int i = 0; i < head_size; i++) {
 64        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
 65    }
 66}
 67
 68template <int block_size>
 69static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
 70    const int tid = threadIdx.x;
 71    const int bid = blockIdx.x;
 72
 73    const int head_size = block_size;
 74    const int batch_i = bid / H;
 75    const int head_i = bid % H;
 76    const int state_size = C * head_size;
 77    const int n_seq_tokens = T / B;
 78
 79    float state[head_size];
 80    __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
 81
 82#ifndef GGML_USE_MUSA
 83    #pragma unroll
 84#endif
 85    for (int i = 0; i < head_size; i++) {
 86        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
 87    }
 88
 89    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
 90        __syncthreads();
 91        _r[tid] = r[t];
 92        _w[tid] = w[t];
 93        _k[tid] = k[t];
 94        _a[tid] = a[t];
 95        _b[tid] = b[t];
 96        __syncthreads();
 97
 98        float sa = 0;
 99        #pragma unroll
100        for (int j = 0; j < head_size; j += 4)
101        {
102            const float4& a = (float4&)(_a[j]);
103            const float4& s = (float4&)(state[j]);
104            sa += a.x * s.x;
105            sa += a.y * s.y;
106            sa += a.z * s.z;
107            sa += a.w * s.w;
108        }
109
110        const float _v = v[t];
111        float y = 0;
112        for (int j = 0; j < head_size; j += 4) {
113            const float4& r = (float4&)(_r[j]);
114            const float4& w = (float4&)(_w[j]);
115            const float4& k = (float4&)(_k[j]);
116            const float4& b = (float4&)(_b[j]);
117            float4& s = (float4&)(state[j]);
118            float4 kv;
119
120            kv.x = k.x * _v;
121            kv.y = k.y * _v;
122            kv.z = k.z * _v;
123            kv.w = k.w * _v;
124
125            s.x = s.x * w.x + kv.x + sa * b.x;
126            s.y = s.y * w.y + kv.y + sa * b.y;
127            s.z = s.z * w.z + kv.z + sa * b.z;
128            s.w = s.w * w.w + kv.w + sa * b.w;
129
130            y += s.x * r.x;
131            y += s.y * r.y;
132            y += s.z * r.z;
133            y += s.w * r.w;
134        }
135        dst[t] = y;
136    }
137
138    #pragma unroll
139    for (int i = 0; i < head_size; i++) {
140        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
141    }
142}
143
// Launches the RWKV-6 WKV kernel for dst = wkv6(k, v, r, tf, td, s) on ctx's stream.
//
// Sources are dst->src[0..5] = k, v, r, tf, td, s. The kernel indexes every buffer as a flat float
// array with no stride handling, so contiguous inputs are assumed (presumably enforced when the graph
// node is built — not checked here).
// Dimensions: B (sequences) = s->ne[1], T (tokens over all sequences) = k->ne[2], H (heads) = k->ne[1],
// C (channels) = dst->ne[0]. The head size C / H selects the kernel instantiation and the block size.
// dst receives T * C per-token outputs followed by the B updated states.
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const float * k_d  = (const float *)dst->src[0]->data;
    const float * v_d  = (const float *)dst->src[1]->data;
    const float * r_d  = (const float *)dst->src[2]->data;
    const float * tf_d = (const float *)dst->src[3]->data;
    const float * td_d = (const float *)dst->src[4]->data;
    const float * s_d  = (const float *)dst->src[5]->data;

    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    float * dst_d = (float *)dst->data;

    cudaStream_t stream = ctx.stream();

    // The kernel reinterprets every source and the destination as float buffers.
    for (int i = 0; i < 6; ++i) {
        GGML_ASSERT(dst->src[i]->type == GGML_TYPE_F32);
    }
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_ASSERT(C % H == 0);
    const int64_t head_size = C / H;
    GGML_ASSERT(head_size == CUDA_WKV_BLOCK_SIZE || head_size == CUDA_WKV_BLOCK_SIZE * 2);

    // The kernel gives each sequence T / B tokens. A remainder would leave the trailing tokens
    // unprocessed and their outputs unwritten. B > 0 is tested first so the modulo is defined.
    GGML_ASSERT(B > 0 && T % B == 0);

    // Kernel arguments and all kernel index arithmetic are 32-bit int. Every index it forms is
    // below (T + B * head_size) * C: T * C outputs followed by B states of C * head_size floats.
    GGML_ASSERT((T + B * head_size) * C <= INT_MAX);

    if (head_size == CUDA_WKV_BLOCK_SIZE) {
        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, head_size, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
    } else {
        rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, head_size, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
    }
}
171
// Launches the RWKV-7 WKV kernel for dst = wkv7(r, w, k, v, a, b, s) on ctx's stream.
//
// Sources are dst->src[0..6] = r, w, k, v, a, b, s. The kernel indexes every buffer as a flat float
// array with no stride handling, so contiguous inputs are assumed (presumably enforced when the graph
// node is built — not checked here).
// Dimensions: B (sequences) = s->ne[1], T (tokens over all sequences) = r->ne[2], H (heads) = r->ne[1],
// C (channels) = dst->ne[0]. The head size C / H selects the kernel instantiation and the block size.
// dst receives T * C per-token outputs followed by the B updated states.
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const float * r_d = (const float *)dst->src[0]->data;
    const float * w_d = (const float *)dst->src[1]->data;
    const float * k_d = (const float *)dst->src[2]->data;
    const float * v_d = (const float *)dst->src[3]->data;
    const float * a_d = (const float *)dst->src[4]->data;
    const float * b_d = (const float *)dst->src[5]->data;
    const float * s_d = (const float *)dst->src[6]->data;

    const int64_t B = dst->src[6]->ne[1];
    const int64_t T = dst->src[0]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[1];

    float * dst_d = (float *)dst->data;

    cudaStream_t stream = ctx.stream();

    // The kernel reinterprets every source and the destination as float buffers.
    for (int i = 0; i < 7; ++i) {
        GGML_ASSERT(dst->src[i]->type == GGML_TYPE_F32);
    }
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_ASSERT(C % H == 0);
    const int64_t head_size = C / H;
    GGML_ASSERT(head_size == CUDA_WKV_BLOCK_SIZE || head_size == CUDA_WKV_BLOCK_SIZE * 2);

    // The kernel gives each sequence T / B tokens. A remainder would leave the trailing tokens
    // unprocessed and their outputs unwritten. B > 0 is tested first so the modulo is defined.
    GGML_ASSERT(B > 0 && T % B == 0);

    // Kernel arguments and all kernel index arithmetic are 32-bit int. Every index it forms is
    // below (T + B * head_size) * C: T * C outputs followed by B states of C * head_size floats.
    GGML_ASSERT((T + B * head_size) * C <= INT_MAX);

    if (head_size == CUDA_WKV_BLOCK_SIZE) {
        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, head_size, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
    } else {
        rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, head_size, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
    }
}