1
// Linear ramp used by rope_yarn to blend between interpolated and extrapolated angles.
//   low, high : ramp bounds, expressed in units of i0 / 2 (integer division).
//   i0        : element index within the row.
// Returns 1.0 when i0 / 2 <= low, 0.0 when i0 / 2 >= high, and a linear falloff in
// between. The denominator is floored at 0.001 so high == low does not divide by zero.
float rope_yarn_ramp(const float low, const float high, const uint i0) {
    const float span     = max(0.001f, high - low);
    const float progress = (i0 / 2 - low) / span;
    const float clamped  = min(1.0f, max(0.0f, progress));
    return 1.0f - clamped;
}
6
// Index into the source buffer (rope_data_a) for the element at (i0, i01, i02, i03).
// When RMS_NORM_ROPE_FUSION is enabled only the in-row offset i0 is returned and the
// other coordinates are ignored; otherwise the index is i0 plus the coordinates
// weighted by the nb01 / nb02 / nb03 strides from p.
uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
#if RMS_NORM_ROPE_FUSION
    // Per-row offset in shared memory
    return i0;
#else
    return i0 + i01*p.nb01 + i02*p.nb02 + i03*p.nb03;
#endif
}
16
// Computes the scaled cosine / sine of the rotation angle for one element pair.
//   theta_extrap : un-scaled (extrapolated) angle.
//   i0           : element index within the row; selects the ramp position.
//   cos_theta    : out, cos(angle) * magnitude.
//   sin_theta    : out, sin(angle) * magnitude.
//   p            : reads freq_scale, ext_factor, attn_factor, corr_dims[0..1], is_back.
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
    // Interpolated angle: the extrapolated angle scaled by freq_scale.
    const float theta_interp = p.freq_scale * theta_extrap;

    float angle     = theta_interp;
    float magnitude = p.attn_factor;

    if (p.ext_factor != 0.0f) {
        // Get n-d rotational scaling corrected for extrapolation:
        // mix the interpolated and extrapolated angles by the ramp weight.
        const float blend = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
        angle = theta_interp * (1 - blend) + theta_extrap * blend;

        // Get n-d magnitude scaling corrected for interpolation.
        magnitude *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
    }

    // Backpropagation uses the inverted rotation.
    const float signed_angle = (p.is_back != 0) ? -angle : angle;

    cos_theta = cos(signed_angle) * magnitude;
    sin_theta = sin(signed_angle) * magnitude;
}
36
// Rotates the adjacent element pair (i0, i0 + 1) of row (i1, i2, i3).
// Reads rope_data_a, rope_data_pos[i2] and, when p.has_ff is set, rope_data_ff[i0 / 2];
// writes two elements of rope_data_d. Indices with i0 >= p.ne00 do nothing; indices
// with i0 >= p.n_dims are copied through without rotation.
void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint src = rope_a_coord(i0, i1, i2, i3, p);

    uint dst;
    if (p.set_rows_stride != 0) {
        // Fusion optimization: ROPE + VIEW + SET_ROWS.
        // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
        dst = i1*p.nb11 + i0 + rope_data_i[i2].x * p.set_rows_stride;
    } else {
        dst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    }

    if (i0 >= p.n_dims) {
        // Outside the rotated range: copy the pair unchanged (with type conversion).
        for (uint k = 0; k < 2; ++k) {
            rope_data_d[dst + k] = ROPE_D_TYPE(rope_data_a[src + k]);
        }
        return;
    }

    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);

    // Optional per-pair frequency divisor; 1.0 when not provided.
    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[i0/2];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    const float x0 = float(rope_data_a[src + 0]);
    const float x1 = float(rope_data_a[src + 1]);

    // Standard 2D rotation of (x0, x1) by the computed angle.
    rope_data_d[dst + 0] = ROPE_D_TYPE(x0*c - x1*s);
    rope_data_d[dst + 1] = ROPE_D_TYPE(x0*s + x1*c);
}
72
// Rotates the element pair (i0 / 2, i0 / 2 + n_dims / 2) of row (i1, i2, i3).
// Reads rope_data_a, rope_data_pos[i2] and, when p.has_ff is set, rope_data_ff[i0 / 2];
// writes two elements of rope_data_d. Indices with i0 >= p.ne00 do nothing; indices
// with i0 >= p.n_dims copy two adjacent elements through without rotation.
void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint half_i0 = i0/2;
    const uint src = rope_a_coord(half_i0, i1, i2, i3, p);

    uint dst;
    if (p.set_rows_stride != 0) {
        // Fusion optimization: ROPE + VIEW + SET_ROWS.
        // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
        dst = i1*p.nb11 + half_i0 + rope_data_i[i2].x * p.set_rows_stride;
    } else {
        dst = half_i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    }

    if (i0 >= p.n_dims) {
        // Outside the rotated range: the base indices already contain i0 / 2, and the
        // pass-through adds i0 / 2 again before copying elements +0 and +1.
        rope_data_d[dst + half_i0 + 0] = ROPE_D_TYPE(rope_data_a[src + half_i0 + 0]);
        rope_data_d[dst + half_i0 + 1] = ROPE_D_TYPE(rope_data_a[src + half_i0 + 1]);
        return;
    }

    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);

    // Optional per-pair frequency divisor; 1.0 when not provided.
    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[half_i0];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    // The two members of a pair sit n_dims / 2 elements apart.
    const uint pair_offset = p.n_dims/2;

    const float x0 = float(rope_data_a[src + 0]);
    const float x1 = float(rope_data_a[src + pair_offset]);

    rope_data_d[dst + 0]           = ROPE_D_TYPE(x0*c - x1*s);
    rope_data_d[dst + pair_offset] = ROPE_D_TYPE(x0*s + x1*c);
}
108
109
// Multi-section variant: same pair layout as rope_neox (i0 / 2 and i0 / 2 + n_dims / 2),
// but the position value is taken from one of four rows of rope_data_pos
// (row r lives at rope_data_pos[i2 + p.ne02 * r]) selected from p.sections.
// Indices with i0 >= p.ne00 do nothing; indices with i0 >= p.n_dims are copied through.
void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint half_i0 = i0/2;
    const uint src = rope_a_coord(half_i0, i1, i2, i3, p);

    uint dst;
    if (p.set_rows_stride != 0) {
        // Fusion optimization: ROPE + VIEW + SET_ROWS.
        // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
        dst = i1*p.nb11 + half_i0 + rope_data_i[i2].x * p.set_rows_stride;
    } else {
        dst = half_i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    }

    if (i0 >= p.n_dims) {
        // Outside the rotated range: the base indices already contain i0 / 2, and the
        // pass-through adds i0 / 2 again before copying elements +0 and +1.
        rope_data_d[dst + half_i0 + 0] = ROPE_D_TYPE(rope_data_a[src + half_i0 + 0]);
        rope_data_d[dst + half_i0 + 1] = ROPE_D_TYPE(rope_data_a[src + half_i0 + 1]);
        return;
    }

    const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
    const int sec_w = p.sections[1] + p.sections[0];
    const uint sector = half_i0 % sect_dims;

    // Pick which of the four position rows drives this pair.
    uint pos_row;
    if (p.is_imrope != 0) {
        // Row chosen by sector % 3, as long as the sector is below 3 * sections[row];
        // everything else uses row 3.
        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
            pos_row = 1;
        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
            pos_row = 2;
        } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
            pos_row = 0;
        } else {
            pos_row = 3;
        }
    } else {
        // Contiguous ranges: [0, s0) -> 0, [s0, s0+s1) -> 1, [s0+s1, s0+s1+s2) -> 2, rest -> 3.
        if (sector < p.sections[0]) {
            pos_row = 0;
        } else if (sector < sec_w) {
            pos_row = 1;
        } else if (sector < sec_w + p.sections[2]) {
            pos_row = 2;
        } else {
            pos_row = 3;
        }
    }

    const float theta_base = rope_data_pos[i2 + p.ne02 * pos_row]*pow(p.theta_scale, i0/2.0f);

    // Optional per-pair frequency divisor; 1.0 when not provided.
    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[half_i0];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    // The two members of a pair sit n_dims / 2 elements apart.
    const uint pair_offset = p.n_dims/2;

    const float x0 = float(rope_data_a[src + 0]);
    const float x1 = float(rope_data_a[src + pair_offset]);

    rope_data_d[dst + 0]           = ROPE_D_TYPE(x0*c - x1*s);
    rope_data_d[dst + pair_offset] = ROPE_D_TYPE(x0*s + x1*c);
}
173
// Vision variant: rotates the element pair (i0 / 2, i0 / 2 + n_dims) of row (i1, i2, i3).
// The position comes from rope_data_pos[i2] for sectors inside sections[0] and from
// rope_data_pos[i2 + p.ne02] for sectors inside the following sections[1]; the exponent
// of theta_scale is the sector's offset inside its own section.
// Indices with i0 >= p.ne00 do nothing.
void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
    if (i0 >= p.ne00) {
        return;
    }

    const uint half_i0 = i0/2;
    const uint dst = half_i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
    const uint src = rope_a_coord(half_i0, i1, i2, i3, p);

    // Combined width of the two sections; also the period of the sector index.
    const int sect_dims = p.sections[0] + p.sections[1];
    const uint sector = half_i0 % sect_dims;

    // Stays 0 if the sector falls in neither section.
    float theta_base = 0.0;
    if (sector < p.sections[0]) {
        const uint local_idx = sector;
        theta_base = rope_data_pos[i2]*pow(p.theta_scale, local_idx);
    } else if (sector < sect_dims) {
        const uint local_idx = sector - p.sections[0];
        theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, local_idx);
    }

    // Optional per-pair frequency divisor; 1.0 when not provided.
    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[half_i0];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    // Unlike the neox / multi variants, the pair members are a full n_dims apart.
    const float x0 = float(rope_data_a[src + 0]);
    const float x1 = float(rope_data_a[src + p.n_dims]);

    rope_data_d[dst + 0]        = ROPE_D_TYPE(x0*c - x1*s);
    rope_data_d[dst + p.n_dims] = ROPE_D_TYPE(x0*s + x1*c);
}
207