enable f16; // needed because DST_INNER_TYPE is f16 unless DST_F32 is defined

// Scalar type stored in the destination buffer.
#ifdef DST_F32
#define DST_INNER_TYPE f32
#else
#define DST_INNER_TYPE f16
#endif

// Per-invocation element width: with VEC4 each buffer element is a 4-wide
// vector, otherwise a single scalar. VEC_SIZE is the number of scalars per element.
#ifdef VEC4
#define SRC_TYPE vec4<f32>
#define DST_TYPE vec4<DST_INNER_TYPE>
#define VEC_SIZE 4
#else
#define SRC_TYPE f32
#define DST_TYPE DST_INNER_TYPE
#define VEC_SIZE 1
#endif

// Source values (always f32-based).
@group(0) @binding(0)
var<storage, read_write> src: array<SRC_TYPE>;

// Row indices, viewed as u32 words. With I64_IDX each index spans two words.
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;

// Destination buffer that receives the converted values.
@group(0) @binding(2)
var<storage, read_write> dst: array<DST_TYPE>;

// With I64_IDX an extra error flag occupies binding 3, which moves the
// uniform parameters to binding 4; otherwise the parameters use binding 3.
#ifdef I64_IDX
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
#define PARAMS_BINDING 4
#else
#define PARAMS_BINDING 3
#endif
35
// Uniform block describing where each tensor starts and how it is laid out.
// Offsets and strides are expressed in elements.
struct Params {
    // Starting offsets of the three buffers.
    offset_src: u32,
    offset_idx: u32,
    offset_dst: u32,

    // Source strides for dimensions 1..3.
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Index strides for dimensions 0..2.
    stride_idx0: u32,
    stride_idx1: u32,
    stride_idx2: u32,

    // Destination strides for dimensions 1..3.
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Source shape: row length (ne0), rows per plane (n_rows), and the two outer dims.
    ne0: u32,
    n_rows: u32,
    ne2: u32,
    ne3: u32,

    // Extents of idx along its dimensions 1 and 2; source coordinates are
    // reduced modulo these before indexing idx.
    idx1: u32,
    idx2: u32,
};

// PARAMS_BINDING is 3, or 4 when I64_IDX adds the error flag at binding 3.
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
67
// Copies every row of `src` into the `dst` row selected by the matching entry
// of `idx`, converting each element from SRC_TYPE to DST_TYPE on the way.
// One invocation handles one SRC_TYPE element (VEC_SIZE scalars).
// WG_SIZE is not defined in this file; presumably it is supplied when the shader is built.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Invocations past the last source element have nothing to do.
    let total_elems = (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE;
    if (gid.x >= total_elems) {
        return;
    }

    // Split the flat invocation id into a flat row number and a column inside that row.
    let row_len = params.ne0 / VEC_SIZE;
    let flat_row = gid.x / row_len;
    let col = gid.x % row_len;

    // Unflatten the row number into source coordinates for dimensions 3, 2 and 1.
    let rows_per_dim3 = params.ne2 * params.n_rows;
    let src3 = flat_row / rows_per_dim3;
    let row_in_dim3 = flat_row % rows_per_dim3;
    let src2 = row_in_dim3 / params.n_rows;
    let src1 = row_in_dim3 % params.n_rows;

    // Locate the index entry for this row. The two outer source coordinates
    // wrap around the extents of idx (modulo idx1 / idx2).
    let idx_elem = params.offset_idx
        + src1 * params.stride_idx0
        + (src2 % params.idx1) * params.stride_idx1
        + (src3 % params.idx2) * params.stride_idx2;

#ifdef I64_IDX
    // Each 64-bit index occupies two consecutive u32 words: the first is used
    // as the row number, the second must be zero for that to be valid.
    let first_word = idx_elem * 2;
    let target_row = idx[first_word];
    let upper_bits = idx[first_word + 1];

    if (upper_bits != 0) {
        // Upper bits of index are not zero, output will be incorrect
        atomicStore(&error, 1);
        return;
    }
#else
    let target_row = idx[idx_elem];
#endif

    // Scalar offsets of the destination and source rows; dividing by VEC_SIZE
    // converts them to element offsets of the (possibly vectorized) arrays.
    let dst_row_start = params.offset_dst
        + target_row * params.stride_dst1
        + src2 * params.stride_dst2
        + src3 * params.stride_dst3;
    let src_row_start = params.offset_src
        + src1 * params.stride_src1
        + src2 * params.stride_src2
        + src3 * params.stride_src3;

    dst[dst_row_start / VEC_SIZE + col] = DST_TYPE(src[src_row_start / VEC_SIZE + col]);
}