1
  2float rope_yarn_ramp(const float low, const float high, const uint i0) {
  3    const float y = (i0 / 2 - low) / max(0.001f, high - low);
  4    return 1.0f - min(1.0f, max(0.0f, y));
  5}
  6
  7uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
  8#if RMS_NORM_ROPE_FUSION
  9    // Per-row offset in shared memory
 10    const uint ix = i0;
 11#else
 12    const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
 13#endif
 14    return ix;
 15}
 16
 17void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
 18    float mscale = p.attn_factor;
 19    // Get n-d rotational scaling corrected for extrapolation
 20    float theta_interp = p.freq_scale * theta_extrap;
 21    float theta = theta_interp;
 22    if (p.ext_factor != 0.0f) {
 23        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
 24        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 25
 26        // Get n-d magnitude scaling corrected for interpolation
 27        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
 28    }
 29    // Backprogagation uses inverted rotation
 30    if (p.is_back != 0) {
 31        theta = -theta;
 32    }
 33    cos_theta = cos(theta) * mscale;
 34    sin_theta = sin(theta) * mscale;
 35}
 36
 37void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
 38    if (i0 >= p.ne00) {
 39        return;
 40    }
 41
 42    uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
 43    const uint ix = rope_a_coord(i0, i1, i2, i3, p);
 44
 45    // Fusion optimization: ROPE + VIEW + SET_ROWS.
 46    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
 47    if (p.set_rows_stride != 0) {
 48        idst = i1*p.nb11 + i0;
 49        idst += rope_data_i[i2].x * p.set_rows_stride;
 50    }
 51
 52    if (i0 >= p.n_dims) {
 53        rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]);
 54        rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]);
 55
 56        return;
 57    }
 58
 59    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 60
 61    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 62
 63    float cos_theta, sin_theta;
 64    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
 65
 66    const float x0 = float(rope_data_a[ix + 0]);
 67    const float x1 = float(rope_data_a[ix + 1]);
 68
 69    rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
 70    rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
 71}
 72
 73void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
 74    if (i0 >= p.ne00) {
 75        return;
 76    }
 77
 78    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
 79    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 80
 81    // Fusion optimization: ROPE + VIEW + SET_ROWS.
 82    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
 83    if (p.set_rows_stride != 0) {
 84        idst = i1*p.nb11 + i0/2;
 85        idst += rope_data_i[i2].x * p.set_rows_stride;
 86    }
 87
 88    if (i0 >= p.n_dims) {
 89        rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
 90        rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
 91
 92        return;
 93    }
 94
 95    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 96
 97    const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 98
 99    float cos_theta, sin_theta;
100    rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p);
101
102    const float x0 = float(rope_data_a[ix + 0]);
103    const float x1 = float(rope_data_a[ix + p.n_dims/2]);
104
105    rope_data_d[idst + 0]          = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta);
106    rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
107}
108
109
// Rotates the element pair (i0/2, i0/2 + p.n_dims/2) of row (i1, i2, i3),
// taking the position from one of four streams in rope_data_pos. Stream k
// is read at index i2 + k * p.ne02; which stream applies is decided from
// p.sections and p.is_imrope.
// For i0 >= p.n_dims the two elements at offsets i0 and i0 + 1 are copied
// through without rotation; indices at or beyond p.ne00 are ignored.
void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint pair = i0/2;
    const uint ix = rope_a_coord(pair, i1, i2, i3, p);

    uint idst;
    if (p.set_rows_stride != 0) {
        // Fusion optimization: ROPE + VIEW + SET_ROWS.
        // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
        idst = i1*p.nb11 + pair + rope_data_i[i2].x * p.set_rows_stride;
    } else {
        idst = pair + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    }

    if (i0 >= p.n_dims) {
        // idst and ix already contain i0/2, so adding it again lands on offset i0.
        rope_data_d[idst + pair + 0] = ROPE_D_TYPE(rope_data_a[ix + pair + 0]);
        rope_data_d[idst + pair + 1] = ROPE_D_TYPE(rope_data_a[ix + pair + 1]);
        return;
    }

    const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
    const int sec_w = p.sections[1] + p.sections[0];
    const uint sector = (i0 / 2) % sect_dims;

    // Select which position stream (0..3) this pair uses.
    uint stream;
    if (p.is_imrope != 0) {
        // Streams 1, 2 and 0 are matched by sector % 3 while the sector is
        // below three times the corresponding section size; everything
        // else falls to stream 3.
        const uint lane = sector % 3;
        if (lane == 1 && sector < 3 * p.sections[1]) {
            stream = 1;
        } else if (lane == 2 && sector < 3 * p.sections[2]) {
            stream = 2;
        } else if (lane == 0 && sector < 3 * p.sections[0]) {
            stream = 0;
        } else {
            stream = 3;
        }
    } else {
        // Sections are laid out back to back:
        // [0, s0) -> 0, [s0, s0+s1) -> 1, [s0+s1, s0+s1+s2) -> 2, rest -> 3.
        if (sector < p.sections[0]) {
            stream = 0;
        } else if (sector < sec_w) {
            stream = 1;
        } else if (sector < sec_w + p.sections[2]) {
            stream = 2;
        } else {
            stream = 3;
        }
    }

    const float theta_base = rope_data_pos[i2 + p.ne02 * stream]*pow(p.theta_scale, i0/2.0f);

    // Optional per-pair frequency divisor; 1.0 when none is supplied.
    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[pair];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    // The partner element sits half the rotated width away.
    const uint half_dims = p.n_dims/2;

    const float v0 = float(rope_data_a[ix + 0]);
    const float v1 = float(rope_data_a[ix + half_dims]);

    rope_data_d[idst + 0]         = ROPE_D_TYPE(v0*c - v1*s);
    rope_data_d[idst + half_dims] = ROPE_D_TYPE(v0*s + v1*c);
}
173
// Rotates the element pair (i0/2, i0/2 + p.n_dims) of row (i1, i2, i3)
// using two position streams from rope_data_pos: index i2 for sectors in
// the first section and i2 + p.ne02 for sectors in the second. The exponent
// applied to p.theta_scale is the sector's offset within its section.
// Indices at or beyond p.ne00 are ignored.
void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint pair = i0/2;
    const uint idst = pair + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    const uint ix = rope_a_coord(pair, i1, i2, i3, p);

    // Combined width of the two sections; used both to wrap the pair index
    // and as the upper bound of the second section.
    const int sect_dims = p.sections[0] + p.sections[1];
    const uint sector = (i0 / 2) % sect_dims;

    // Stays 0.0 if the sector falls into neither section.
    float theta_base = 0.0;
    if (sector < p.sections[0]) {
        const uint p0 = sector;
        theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
    } else if (sector < sect_dims) {
        const uint p0 = sector - p.sections[0];
        theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
    }

    // Optional per-pair frequency divisor; 1.0 when none is supplied.
    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[pair];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    const float v0 = float(rope_data_a[ix + 0]);
    const float v1 = float(rope_data_a[ix + p.n_dims]);

    rope_data_d[idst + 0]        = ROPE_D_TYPE(v0*c - v1*s);
    rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(v0*s + v1*c);
}
207