1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 2
 3//------------------------------------------------------------------------------
 4// solve_tri
 5//------------------------------------------------------------------------------
 6kernel void kernel_solve_tri_f32(
 7        global uchar * src0,
 8        ulong offset0,
 9        global uchar * src1,
10        ulong offset1,
11        global uchar * dst,
12        ulong offsetd,
13        int n,
14        int k,
15        ulong nb00,
16        ulong nb01,
17        ulong nb02,
18        ulong nb03,
19        ulong nb10,
20        ulong nb11,
21        ulong nb12,
22        ulong nb13,
23        ulong nb0,
24        ulong nb1,
25        ulong nb2,
26        ulong nb3
27) {
28    int col = get_global_id(0);
29    int i2 = get_global_id(1);
30    int i3 = get_global_id(2);
31
32    global const uchar * Lb = src0 + offset0 + i2 * nb02 + i3 * nb03;
33    global const uchar * Bb = src1 + offset1 + i2 * nb12 + i3 * nb13;
34    global       uchar * Xb = dst + offsetd + i2 * nb2 + i3 * nb3;
35
36    for(int row = 0; row < n; ++row){
37        global const float *pB = (global const float *)(Bb + row * nb11 + col * nb10);
38
39        float sum = 0.0f;
40        for(int j = 0; j < row; ++j){
41            global const float *pL = (global const float *)(Lb + row * nb01 + j * nb00);
42            global const float *pX = (global const float *)(Xb + j * nb1 + col * nb0);
43            sum += (*pL) * (*pX);
44        }
45
46        global const float * pDiag = (global const float *)(Lb + row * nb01 + row *nb00);
47        global float * pOut = (global float *)(Xb + row * nb1 + col *nb0);
48
49        *pOut = ((* pB) - sum) / (*pDiag);
50    }
51}