1#extension GL_EXT_shader_16bit_storage : require
 2#extension GL_EXT_control_flow_attributes : require
 3
 4#include "rte.glsl"
 5#include "utils.glsl"
 6#if RMS_NORM_ROPE_FUSION
 7#include "rope_params.glsl"
 8#endif
 9
// Push-constant parameters, accessed through the instance name `p`.
// Member order defines the push-constant layout, so it must not be reordered.
layout (push_constant) uniform parameter
{
    // Not referenced by the helpers in this file; presumably the total
    // element count -- TODO confirm against the shaders that use it.
    uint ne;

    // src0: ne0x are forwarded to get_indices() to split a flat index,
    // nb0x are the per-dimension multipliers applied in src0_idx().
    uint ne00, ne01, ne02, ne03;
    uint nb00, nb01, nb02, nb03;

    // src1: ne1x are the wrap-around moduli used by src1_idx() when
    // norepeat is false, nb1x are the per-dimension multipliers.
    uint ne10, ne11, ne12, ne13;
    uint nb10, nb11, nb12, nb13;

    // dst: nb2x are the per-dimension multipliers applied in dst_idx().
    // ne2x are not referenced by the helpers in this file.
    uint ne20, ne21, ne22, ne23;
    uint nb20, nb21, nb22, nb23;

    // Three packed offsets: bits 16 and up -> get_aoffset(),
    // bits 8..15 -> get_boffset(), bits 0..7 -> get_doffset().
    uint misalign_offsets;

    // Not referenced in this file; presumably operation-specific
    // parameters -- TODO confirm.
    float param1, param2;
    int param3;

#if RMS_NORM_ROPE_FUSION
    // Extra parameters, present only in the fused build variant.
    rope_params rope;
#endif
} p;
22
// Storage buffer bindings. Omitted when RMS_NORM_ROPE_FUSION is set
// (presumably that variant supplies its own declarations -- confirm).
#if !RMS_NORM_ROPE_FUSION

// Binding 0: read-only input viewed as an array of A_TYPE.
layout (binding = 0) readonly buffer A
{
    A_TYPE data_a[];
};

// Optional alternative typed views, declared on the same binding 0.
#ifdef A_TYPE_PACKED16
layout (binding = 0) readonly buffer A_PACKED16
{
    A_TYPE_PACKED16 data_a_packed16[];
};
#endif

#ifdef A_TYPE_PACKED32
layout (binding = 0) readonly buffer A_PACKED32
{
    A_TYPE_PACKED32 data_a_packed32[];
};
#endif

// Binding 1: read-only input viewed as an array of B_TYPE.
layout (binding = 1) readonly buffer B
{
    B_TYPE data_b[];
};

// Binding 2: write-only output viewed as an array of D_TYPE.
layout (binding = 2) writeonly buffer D
{
    D_TYPE data_d[];
};

#endif
35
// Specialization constant 0 (default: false). When true, src1_idx() applies
// the nb1x multipliers to the incoming indices directly; when false, each
// index is first reduced with fastmod() against the matching ne1x value.
// Meant to be enabled when src0 and src1 have the same shape, so the indices
// can be reused without the extra modulus.
layout (constant_id = 0) const bool norepeat = false;
38
// Collapses the 3D global invocation ID into a single linear index:
// x contributes directly, each y step counts for 512 and each z step for
// 262144 (= 512 * 512).
// NOTE(review): the result is only unique per invocation if x and y each stay
// below 512 -- confirm against the host-side dispatch sizes.
uint get_idx() {
    const uint y_step = 512u;
    const uint z_step = 262144u; // 512 * 512
    const uvec3 gid = gl_GlobalInvocationID;

    return gid.z * z_step + gid.y * y_step + gid.x;
}
42
// Returns the field stored in p.misalign_offsets from bit 16 upwards.
// Going by the name this is the offset for buffer A; units are not visible
// here -- TODO confirm.
uint get_aoffset() {
    const uint a_shift = 16u;
    return p.misalign_offsets >> a_shift;
}
// Returns the 8-bit field stored in bits 8..15 of p.misalign_offsets.
// Going by the name this is the offset for buffer B; units are not visible
// here -- TODO confirm.
uint get_boffset() {
    const uint b_shift = 8u;
    const uint byte_mask = 0xFFu;
    const uint shifted = p.misalign_offsets >> b_shift;
    return shifted & byte_mask;
}
// Returns the 8-bit field stored in bits 0..7 of p.misalign_offsets.
// Going by the name this is the offset for buffer D; units are not visible
// here -- TODO confirm.
uint get_doffset() {
    const uint byte_mask = 0xFFu;
    return p.misalign_offsets & byte_mask;
}
46
47
// Convenience overload: forwards to the nine-argument get_indices() (defined
// outside this file's visible code), supplying p.ne00..p.ne03 as the trailing
// four arguments.
//   idx        flat index to decompose
//   i00..i03   outputs written by the forwarded call
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
    get_indices(
        idx,
        i00, i01, i02, i03,
        p.ne00, p.ne01, p.ne02, p.ne03
    );
}
51
// Linear position for src0: the sum of each index multiplied by its
// p.nb0x multiplier. Unsigned arithmetic wraps on overflow.
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
    uint pos = i03 * p.nb03;
    pos += i02 * p.nb02;
    pos += i01 * p.nb01;
    pos += i00 * p.nb00;
    return pos;
}
55
// Linear position for src1: the sum of each index multiplied by its
// p.nb1x multiplier. Unless norepeat is set, every index is first reduced
// with fastmod() against the matching p.ne1x value.
uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
    // Fast path: use the indices as they are.
    if (norepeat) {
        return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
    }

    // General path: wrap each index before applying its multiplier.
    const uint w3 = fastmod(i03, p.ne13);
    const uint w2 = fastmod(i02, p.ne12);
    const uint w1 = fastmod(i01, p.ne11);
    const uint w0 = fastmod(i00, p.ne10);

    return w3*p.nb13 + w2*p.nb12 + w1*p.nb11 + w0*p.nb10;
}
63
// Linear position for dst: the sum of each index multiplied by its
// p.nb2x multiplier. Unsigned arithmetic wraps on overflow.
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
    uint pos = i03 * p.nb23;
    pos += i02 * p.nb22;
    pos += i01 * p.nb21;
    pos += i00 * p.nb20;
    return pos;
}