1#version 450
2
3#include "types.glsl"
4#include "generic_unary_head.glsl"
5
6#extension GL_EXT_control_flow_attributes : require
7
// Number of invocations per workgroup along x. The workgroup is one-dimensional
// (y and z are 1). main() steps each invocation's element index by this amount,
// so it must stay consistent with main()'s per-invocation iteration count.
const uint num_threads = 128u;

layout(
    local_size_x = num_threads,
    local_size_y = 1,
    local_size_z = 1
) in;
11
// Copies the single element at position idx from data_a to data_d, converting
// it to the destination element type. The variant is chosen at shader-compile time:
//  - DATA_D_BF16: read the source as float, round it with fp32_to_bf16, and
//    wrap the result in D_TYPE.
//  - default: convert with an explicit D_TYPE(...) constructor.
//  - OPTIMIZATION_ERROR_WORKAROUND: plain assignment without the explicit
//    constructor. This presumably sidesteps a shader-compiler optimization
//    problem with the constructor form; confirm against the build script that
//    defines the macro.
// The caller must guarantee idx < p.ne; no bounds check is done here.
// This is the single place the conversion logic lives, so both paths in main()
// stay in sync.
void copy_element(const uint idx) {
#if defined(DATA_D_BF16)
    float f = float(data_a[get_aoffset() + idx]);
    data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
    data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
    data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
}

// Each invocation copies up to num_iter elements, starting at get_idx() and
// advancing by num_threads after each copy. Elements at or beyond p.ne are skipped.
void main() {
    uint idx = get_idx();

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 4;

    // Fast path: if the last element this invocation would touch is in bounds,
    // every earlier one is too, so the per-element bounds check can be skipped.
    if (idx + (num_iter - 1) * num_threads < p.ne) {
        [[unroll]] for (uint i = 0; i < num_iter; ++i) {
            copy_element(idx);
            idx += num_threads;
        }
    } else {
        // Tail path: at least one element is out of bounds, so check each one.
        // Once idx >= p.ne, the `continue` also skips the increment, so every
        // remaining iteration is skipped as well.
        [[unroll]] for (uint i = 0; i < num_iter; ++i) {
            if (idx >= p.ne) {
                continue;
            }
            copy_element(idx);
            idx += num_threads;
        }
    }
}