blob: 5db0f9c3c90e474497e62710ae7b8583055f2b72 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <assert.h>
7
8#include <xnnpack/gemm.h>
9#include <xnnpack/math.h>
10
11
Marat Dukhan467f6362020-05-22 23:21:55 -070012$assert ACTIVATION in ["LINEAR", "RELU", "MINMAX"]
Marat Dukhan436ebe62019-12-04 15:10:12 -080013$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
14$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
Marat Dukhan163a7e62020-04-09 04:19:26 -070015$KERNEL = "gemminc" if INC else "gemm"
Marat Dukhan467f6362020-05-22 23:21:55 -070016$SUFFIX = {"LINEAR": "", "RELU": "_relu", "MINMAX": "_minmax"}[ACTIVATION]
17$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
Marat Dukhan163a7e62020-04-09 04:19:26 -070018void xnn_f32_${KERNEL}${SUFFIX}_ukernel_${MR}x${NR}__${"wasm" if WASM else "scalar"}(
XNNPACK Teamb455b122019-09-27 18:10:33 -070019 size_t mr,
20 size_t nc,
21 size_t kc,
22 const float* restrict a,
23 size_t a_stride,
24 const float* restrict w,
25 float* restrict c,
26 size_t cm_stride,
27 size_t cn_stride,
28 $if INC:
29 const float*restrict acc,
Marat Dukhanf196d012020-04-15 11:50:03 -070030 const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
XNNPACK Teamb455b122019-09-27 18:10:33 -070031{
32 assert(mr != 0);
33 assert(mr <= ${MR});
34 assert(nc != 0);
35 assert(kc != 0);
36 assert(kc % sizeof(float) == 0);
37 assert(a != NULL);
38 assert(w != NULL);
39 assert(c != NULL);
40 $if INC:
41 assert(acc != NULL);
42
43 const float* a0 = a;
44 float* c0 = c;
45 $for M in range(1, MR):
46 const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
47 float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
48 $if M % 2 == 0:
49 if XNN_UNPREDICTABLE(mr <= ${M}) {
50 a${M} = a${M-1};
51 c${M} = c${M-1};
52 }
53 $elif M + 1 == MR:
54 if XNN_UNPREDICTABLE(mr != ${M+1}) {
55 a${M} = a${M-1};
56 c${M} = c${M-1};
57 }
58 $else:
59 if XNN_UNPREDICTABLE(mr < ${M+1}) {
60 a${M} = a${M-1};
61 c${M} = c${M-1};
62 }
63
Marat Dukhan163a7e62020-04-09 04:19:26 -070064 $if ACTIVATION == "MINMAX":
65 const float vmin = params->scalar.min;
66 const float vmax = params->scalar.max;
XNNPACK Teamb455b122019-09-27 18:10:33 -070067 do {
68 $if INC:
69 $for M in range(MR):
70 $for N in range(NR):
71 float vacc${M}${N} = acc[${M*NR+N}];
72 acc += ${MR*NR};
73 $else:
74 $for N in range(NR):
75 float vacc0${N} = w[${N}];
76 w += ${NR};
77 $for M in range(1, MR):
78 $for N in range(NR):
79 float vacc${M}${N} = vacc0${N};
80
81 size_t k = kc;
82 do {
83 $for M in range(MR):
84 const float va${M} = *a${M}++;
85
86 $for N in range(NR):
87 const float vb${N} = w[${N}];
88 w += ${NR};
89
90 $for M in range(MR):
91 $for N in range(NR):
92 vacc${M}${N} += va${M} * vb${N};
93
94 k -= sizeof(float);
95 } while (k != 0);
96
Marat Dukhan163a7e62020-04-09 04:19:26 -070097 $if ACTIVATION == "MINMAX":
98 $for M in range(MR):
99 $for N in range(NR):
100 vacc${M}${N} = ${MAX_F32}(vacc${M}${N}, vmin);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700101
Marat Dukhan163a7e62020-04-09 04:19:26 -0700102 $for M in range(MR):
103 $for N in range(NR):
104 vacc${M}${N} = ${MIN_F32}(vacc${M}${N}, vmax);
Marat Dukhan467f6362020-05-22 23:21:55 -0700105 $elif ACTIVATION == "RELU":
106 $for M in range(MR):
107 $for N in range(NR):
108 vacc${M}${N} = ${MAX_F32}(vacc${M}${N}, 0.0f);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700109
110 if XNN_LIKELY(nc >= ${NR}) {
111 $for M in reversed(range(MR)):
112 $for N in range(NR):
113 c${M}[${N}] = vacc${M}${N};
114 c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
115
116 $for M in reversed(range(MR)):
117 a${M} = (const void*) ((uintptr_t) a${M} - kc);
118
119 nc -= ${NR};
120 } else {
121 $for LOG2N in reversed(range(NR.bit_length() - 1)):
122 if (nc & ${1 << LOG2N}) {
123 $for M in reversed(range(MR)):
124 $for N in range(1 << LOG2N):
125 c${M}[${N}] = vacc${M}${N};
126 $if LOG2N != 0:
127 $for N in range(1 << (LOG2N - 1)):
128 vacc${M}${N} = vacc${M}${N + (1 << LOG2N)};
129 c${M} += ${1 << LOG2N};
130 }
131
132 nc = 0;
133 }
134 } while (nc != 0);
135}