blob: 90d703f273cea27ed0d3413acb3baaa48d969c00 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#include <assert.h>
7
8#include <xnnpack/gemm.h>
9#include <xnnpack/math.h>
10
11
Marat Dukhan436ebe62019-12-04 15:10:12 -080012$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
13$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
14void xnn_f32_gemm${"inc" if INC else ""}_ukernel_${MR}x${NR}__${"wasm" if WASM else "scalar"}(
XNNPACK Teamb455b122019-09-27 18:10:33 -070015 size_t mr,
16 size_t nc,
17 size_t kc,
18 const float* restrict a,
19 size_t a_stride,
20 const float* restrict w,
21 float* restrict c,
22 size_t cm_stride,
23 size_t cn_stride,
24 $if INC:
25 const float*restrict acc,
26 const union xnn_f32_output_params params[restrict static 1])
27{
28 assert(mr != 0);
29 assert(mr <= ${MR});
30 assert(nc != 0);
31 assert(kc != 0);
32 assert(kc % sizeof(float) == 0);
33 assert(a != NULL);
34 assert(w != NULL);
35 assert(c != NULL);
36 $if INC:
37 assert(acc != NULL);
38
39 const float* a0 = a;
40 float* c0 = c;
41 $for M in range(1, MR):
42 const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
43 float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
44 $if M % 2 == 0:
45 if XNN_UNPREDICTABLE(mr <= ${M}) {
46 a${M} = a${M-1};
47 c${M} = c${M-1};
48 }
49 $elif M + 1 == MR:
50 if XNN_UNPREDICTABLE(mr != ${M+1}) {
51 a${M} = a${M-1};
52 c${M} = c${M-1};
53 }
54 $else:
55 if XNN_UNPREDICTABLE(mr < ${M+1}) {
56 a${M} = a${M-1};
57 c${M} = c${M-1};
58 }
59
60 do {
61 $if INC:
62 $for M in range(MR):
63 $for N in range(NR):
64 float vacc${M}${N} = acc[${M*NR+N}];
65 acc += ${MR*NR};
66 $else:
67 $for N in range(NR):
68 float vacc0${N} = w[${N}];
69 w += ${NR};
70 $for M in range(1, MR):
71 $for N in range(NR):
72 float vacc${M}${N} = vacc0${N};
73
74 size_t k = kc;
75 do {
76 $for M in range(MR):
77 const float va${M} = *a${M}++;
78
79 $for N in range(NR):
80 const float vb${N} = w[${N}];
81 w += ${NR};
82
83 $for M in range(MR):
84 $for N in range(NR):
85 vacc${M}${N} += va${M} * vb${N};
86
87 k -= sizeof(float);
88 } while (k != 0);
89
90 const float vmin = params->scalar.min;
91 $for M in range(MR):
92 $for N in range(NR):
Marat Dukhan436ebe62019-12-04 15:10:12 -080093 vacc${M}${N} = ${MAX_F32}(vacc${M}${N}, vmin);
XNNPACK Teamb455b122019-09-27 18:10:33 -070094
95 const float vmax = params->scalar.max;
96 $for M in range(MR):
97 $for N in range(NR):
Marat Dukhan436ebe62019-12-04 15:10:12 -080098 vacc${M}${N} = ${MIN_F32}(vacc${M}${N}, vmax);
XNNPACK Teamb455b122019-09-27 18:10:33 -070099
100 if XNN_LIKELY(nc >= ${NR}) {
101 $for M in reversed(range(MR)):
102 $for N in range(NR):
103 c${M}[${N}] = vacc${M}${N};
104 c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
105
106 $for M in reversed(range(MR)):
107 a${M} = (const void*) ((uintptr_t) a${M} - kc);
108
109 nc -= ${NR};
110 } else {
111 $for LOG2N in reversed(range(NR.bit_length() - 1)):
112 if (nc & ${1 << LOG2N}) {
113 $for M in reversed(range(MR)):
114 $for N in range(1 << LOG2N):
115 c${M}[${N}] = vacc${M}${N};
116 $if LOG2N != 0:
117 $for N in range(1 << (LOG2N - 1)):
118 vacc${M}${N} = vacc${M}${N + (1 << LOG2N)};
119 c${M} += ${1 << LOG2N};
120 }
121
122 nc = 0;
123 }
124 } while (nc != 0);
125}