enable f16; #ifdef DST_F32 #define DST_INNER_TYPE f32 #else #define DST_INNER_TYPE f16 #endif #ifdef VEC4 #define SRC_TYPE vec4 #define DST_TYPE vec4 #define VEC_SIZE 4 #else #define SRC_TYPE f32 #define DST_TYPE DST_INNER_TYPE #define VEC_SIZE 1 #endif @group(0) @binding(0) var src: array; @group(0) @binding(1) var idx: array; @group(0) @binding(2) var dst: array; #ifdef I64_IDX @group(0) @binding(3) var error: atomic; #define PARAMS_BINDING 4 #else #define PARAMS_BINDING 3 #endif struct Params { offset_src: u32, // in elements offset_idx: u32, // in elements offset_dst: u32, // in elements // Strides (in elements) stride_src1: u32, stride_src2: u32, stride_src3: u32, stride_idx0: u32, stride_idx1: u32, stride_idx2: u32, stride_dst1: u32, stride_dst2: u32, stride_dst3: u32, // Shape of src ne0: u32, n_rows: u32, ne2: u32, ne3: u32, // Shape of idx idx1: u32, idx2: u32, }; @group(0) @binding(PARAMS_BINDING) var params: Params; @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) { return; } // getting the row from gid let elems_per_row = params.ne0 / VEC_SIZE; var i = gid.x / elems_per_row; let i_src3 = i / (params.ne2 * params.n_rows); i = i % (params.ne2 * params.n_rows); let i_src2 = i / params.n_rows; let i_src1 = i % params.n_rows; let i_idx2 = i_src3 % params.idx2; let i_idx1 = i_src2 % params.idx1; let i_idx0 = i_src1; #ifdef I64_IDX let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; let idx_val = idx[idx_high]; let idx_low_val = idx[idx_high + 1]; if (idx_low_val != 0) { // Upper bits of index are not zero, output will be incorrect atomicStore(&error, 1); return; } #else let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; let idx_val = idx[idx_i]; #endif let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; let col_idx = (gid.x % elems_per_row); dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]); }