1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 2
 3//------------------------------------------------------------------------------
 4// diag_mask_inf kernels
 5//------------------------------------------------------------------------------
 6kernel void kernel_diag_mask_inf(
 7        global float * src0,
 8        ulong offset0,
 9        global float * dst,
10        ulong offsetd,
11        int ne00,
12        int ne01,
13        int n_past
14) {
15    src0 = (global float*)((global char*)src0 + offset0);
16    dst = (global float*)((global char*)dst + offsetd);
17
18    int i02 = get_global_id(2);
19    int i01 = get_global_id(1);
20    int i00 = get_global_id(0);
21
22    if (i00 > n_past + i01) {
23        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
24    } else {
25        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
26    }
27}
28
// Vectorized diagonal mask: each work-item handles 8 consecutive floats,
// stored as the two float4 values dst[2*gid] and dst[2*gid + 1].
//
// The 8 values are first copied from src0 to dst; then every element whose
// column index exceeds n_past + row index is overwritten with -INFINITY.
//
//   src0, offset0 : input buffer and byte offset to its first element
//   dst,  offsetd : output buffer and byte offset to its first element
//   ne00          : number of floats per row
//   ne01          : number of rows per plane
//   n_past        : shift of the mask diagonal; row r keeps columns 0..n_past+r
//
// NOTE(review): the row index is derived from the first of the 8 floats only,
// so a group is assumed never to straddle a row boundary (i.e. ne00 is a
// multiple of 8) -- confirm at the call site. There is also no bounds check,
// so the launch is presumably sized to exactly (total floats / 8) work-items.
kernel void kernel_diag_mask_inf_8(
        global float4 * src0,
        ulong offset0,
        global float4 * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int n_past
) {
    src0 = (global float4*)((global char*)src0 + offset0);
    dst = (global float4*)((global char*)dst + offsetd);

    int i = 2*get_global_id(0);

    dst[i+0] = src0[i+0];
    dst[i+1] = src0[i+1];

    // Scalar views of the two output vectors. Subscripting a float4 pointer
    // (e.g. (&dst[i+1])[k]) steps by whole float4s and would overwrite the
    // vectors owned by neighbouring work-items; individual lanes have to be
    // addressed through a float pointer instead.
    global float * lo = (global float *)(dst + i);      // columns i00+0 .. i00+3
    global float * hi = (global float *)(dst + i + 1);  // columns i00+4 .. i00+7

    // Decompose the flat float index of the first element into (i02, i01, i00).
    int i4 = 4*i;
    int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
    int i01 = i4/(ne00);      i4 -= i01*ne00;
    int i00 = i4;

    // Walk lanes from the highest column down. As soon as a lane of the upper
    // vector is unmasked, all lower columns are unmasked too, so stop.
    for (int k = 3; k >= 0; --k) {
        if (i00 + 4 + k <= n_past + i01) {
            break;
        }
        hi[k] = -INFINITY;
        if (i00 + k > n_past + i01) {
            lo[k] = -INFINITY;
        }
    }
}