// Destination buffer, addressed as 32-bit words. The `main` entry point
// overwrites whole words, or individual bytes inside a word, that fall in the
// byte range described by `params`.
@group(0) @binding(0)
var<storage, read_write> output_buffer: array<u32>;

 4struct Params {
 5    offset: u32, // in bytes
 6    size: u32,   // in bytes
 7    value: u32,  // 4 8-bit values, which are either repeating (memset_tensor) or may be separate (cleaning up unaligned set_tensor operations)
 8};
 9
// Uniform block carrying the fill range (`offset`, `size`) and the packed
// fill bytes (`value`) read by `main`.
@group(0) @binding(1)
var<uniform> params: Params;

// Pipeline-overridable constants. Neither has a default initializer, so both
// have to be provided when the pipeline is created.

// Workgroup size along x for `main`.
override wg_size: u32;
// Number of bytes each invocation of `main` is responsible for. `main` walks
// this span in 4-byte steps.
// NOTE(review): presumably always a multiple of 4; confirm against the host code.
override bytes_per_thread: u32;

// Fills the byte range [params.offset, params.offset + params.size) of
// `output_buffer` with bytes taken from `params.value`.
//
// Each invocation handles `bytes_per_thread` consecutive bytes beginning at
// params.offset + gid.x * bytes_per_thread, one 32-bit word (4 bytes) per
// loop iteration:
//   - if all 4 bytes lie inside the range, the word is overwritten with
//     `params.value` in a single store;
//   - otherwise the word is patched byte by byte, so bytes at or past the end
//     of the range keep their previous contents.
// Iterations that start at or beyond the end of the range write nothing.
//
// The byte copied into a destination byte is the byte of `params.value` at
// the same position within the word (selected by `idx & 3`).
//
// NOTE(review): the whole-word store indexes `byte_index >> 2`, so it only
// touches exactly the intended 4 bytes when `byte_index` is a multiple of 4.
// Presumably callers pass a 4-byte-aligned `params.offset` and a
// `bytes_per_thread` that is a multiple of 4; confirm against the host code.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x * bytes_per_thread;
    let start = params.offset;
    let end = params.offset + params.size;

    for (var j: u32 = 0u; j < bytes_per_thread; j += 4u) {
        let byte_index = start + i + j;
        if (byte_index + 4u <= end) {
            // Whole word is inside the range: one plain store.
            output_buffer[byte_index >> 2u] = params.value;
        } else {
            // Tail: fewer than 4 bytes of this word are inside the range.
            // Replace only those bytes and preserve the rest of the word.
            for (var k: u32 = 0u; k < 4u; k++) {
                let idx = byte_index + k;
                if (idx < end) {
                    let word_idx = idx >> 2u;
                    // Mask selecting this byte's 8 bits within its word.
                    let byte_mask = 0xffu << ((idx & 3u) * 8u);
                    output_buffer[word_idx] =
                        (output_buffer[word_idx] & ~byte_mask) | (params.value & byte_mask);
                }
            }
        }
    }
}