// Destination buffer for the fill, viewed as 32-bit words. The kernel
// addresses it by byte offset and converts to a word index with `>> 2`.
@group(0) @binding(0)
var<storage, read_write> output_buffer: array<u32>;

// Parameters describing one fill operation.
struct Params {
    // Byte offset into output_buffer at which the fill starts.
    offset: u32,
    // Number of bytes to fill.
    size: u32,
    // Four packed 8-bit values. The byte written at buffer address `a` is taken
    // from bits [(a & 3) * 8, (a & 3) * 8 + 8) of this word. The four bytes are
    // either all the same (memset_tensor) or distinct (cleaning up the unaligned
    // remainder of a set_tensor operation).
    value: u32,
};

// Uniform block carrying the offset, size and fill value for this dispatch.
@group(0) @binding(1)
var<uniform> params: Params;

// Pipeline-overridable constants. Neither has a default initializer, so both
// must be supplied when the compute pipeline is created.
// wg_size: workgroup size along x, used by @workgroup_size on `main`.
override wg_size: u32;
// bytes_per_thread: number of buffer bytes each invocation of `main` is
// responsible for; the kernel walks them in 4-byte (one-word) steps.
override bytes_per_thread: u32;

// Fills the byte range [params.offset, params.offset + params.size) of
// output_buffer with bytes taken from params.value.
//
// The byte stored at buffer address `a` comes from byte lane (a & 3) of
// params.value, i.e. bits [(a & 3) * 8, (a & 3) * 8 + 8).
//
// Work is distributed by whole 32-bit words: invocation `gid.x` owns
// ceil(bytes_per_thread / 4) consecutive words, starting at the word that
// contains params.offset. Because each word has exactly one owner, the masked
// read-modify-write used for partially covered words never races with another
// invocation. Words that lie entirely inside the range are stored directly;
// the first and last words are masked so that bytes outside the range keep
// their previous contents (this also holds when params.offset is not 4-byte
// aligned).
//
// NOTE(review): the dispatch size is chosen by the host and is not visible
// here. For a 4-byte-aligned offset this kernel needs the same coverage as
// before (size bytes). For an unaligned offset the range can span one extra
// word, so the host should size the dispatch for (offset & 3) + size bytes --
// confirm against the caller.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let start = params.offset;
    let end = params.offset + params.size;

    // Nothing to do for an empty range (also rejects a wrapped offset + size).
    if (end <= start) {
        return;
    }

    // Word containing the first byte, and word containing the byte at `end`
    // (which holds the final partial word, or nothing if `end` is aligned).
    let first_word = start >> 2u;
    let last_word = end >> 2u;

    // Byte-lane masks for the boundary words. Shift amounts are at most 24.
    // head_mask keeps lanes at or after `start`; tail_mask keeps lanes strictly
    // before `end` and is 0 when `end` is word-aligned.
    let head_mask = 0xffffffffu << ((start & 3u) * 8u);
    let tail_mask = ~(0xffffffffu << ((end & 3u) * 8u));

    let words_per_thread = (bytes_per_thread + 3u) / 4u;
    let base_word = first_word + gid.x * words_per_thread;

    for (var n: u32 = 0u; n < words_per_thread; n++) {
        let word_idx = base_word + n;

        // Every later word owned by this invocation is past the range too.
        if (word_idx > last_word) {
            break;
        }

        // Select the byte lanes of this word that fall inside [start, end).
        var mask = 0xffffffffu;
        if (word_idx == first_word) {
            mask &= head_mask;
        }
        if (word_idx == last_word) {
            mask &= tail_mask;
        }

        if (mask == 0xffffffffu) {
            // Fully covered word: plain store, no read needed.
            output_buffer[word_idx] = params.value;
        } else if (mask != 0u) {
            // Partially covered word: one read-modify-write that preserves the
            // bytes outside the requested range.
            let existing = output_buffer[word_idx];
            output_buffer[word_idx] = (existing & ~mask) | (params.value & mask);
        }
    }
}