blob: 8e9699edefd38e5f3a4af5d1179c0a0b442f4302 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

Marat Dukhan0f349c42019-11-27 11:58:54 -08006$assert NR % 8 == 0
7$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
Marat Dukhanfda12b82019-11-21 12:27:59 -08008#include <assert.h>
9
10#include <immintrin.h>
11
12#include <xnnpack/gemm.h>
13
14
15$ISA = {0: "avx", 3: "fma3"}[FMA]
Marat Dukhande06f492020-04-09 00:19:31 -070016void xnn_f32_gemm${"inc" if INC else ""}_minmax_ukernel_${MR}x${NR}__${ISA}_broadcast(
Marat Dukhanfda12b82019-11-21 12:27:59 -080017 size_t mr,
18 size_t nc,
19 size_t kc,
20 const float*restrict a,
21 size_t a_stride,
22 const float*restrict w,
23 float*restrict c,
24 size_t cm_stride,
25 size_t cn_stride,
26 $if INC:
27 const float*restrict acc,
Marat Dukhanf196d012020-04-15 11:50:03 -070028 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
Marat Dukhanfda12b82019-11-21 12:27:59 -080029{
30 assert(mr != 0);
31 assert(mr <= ${MR});
32 assert(nc != 0);
33 assert(kc != 0);
34 assert(kc % sizeof(float) == 0);
35 assert(a != NULL);
36 assert(w != NULL);
37 assert(c != NULL);
38 $if INC:
39 assert(acc != NULL);
40
41 const float* a0 = a;
42 float* c0 = c;
43 $for M in range(1, MR):
44 const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
45 float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
46 $if M % 2 == 0:
47 if XNN_UNPREDICTABLE(mr <= ${M}) {
48 a${M} = a${M-1};
49 c${M} = c${M-1};
50 }
51 $elif M + 1 == MR:
52 if XNN_UNPREDICTABLE(mr != ${M+1}) {
53 a${M} = a${M-1};
54 c${M} = c${M-1};
55 }
56 $else:
57 if XNN_UNPREDICTABLE(mr < ${M+1}) {
58 a${M} = a${M-1};
59 c${M} = c${M-1};
60 }
61
62 do {
63 $if INC:
64 $for M in range(MR):
65 $for N in range(0, NR, 8):
66 __m256 vacc${M}x${ABC[N:N+8]} = _mm256_load_ps(acc + ${M*NR+N});
67 acc += ${MR*NR};
68 $else:
69 $for N in range(0, NR, 8):
70 __m256 vacc0x${ABC[N:N+8]} = _mm256_load_ps(w + ${N});
71 $for M in range(1, MR):
72 $for N in range(0, NR, 8):
73 __m256 vacc${M}x${ABC[N:N+8]} = vacc0x${ABC[N:N+8]};
74 w += ${NR};
75
76 size_t k = kc;
77 do {
78 $for M in range(MR):
79 const __m256 va${M} = _mm256_broadcast_ss(a${M});
80 a${M} += 1;
81
82 const __m256 vb${ABC[0:8]} = _mm256_load_ps(w);
83 $for N in range(8, NR, 8):
84 const __m256 vb${ABC[N:N+8]} = _mm256_load_ps(w + ${N});
85 w += ${NR};
86
87 $for N in range(0, NR, 8):
88 $for M in range(MR):
89 $if FMA == 3:
90 vacc${M}x${ABC[N:N+8]} = _mm256_fmadd_ps(va${M}, vb${ABC[N:N+8]}, vacc${M}x${ABC[N:N+8]});
91 $else:
92 vacc${M}x${ABC[N:N+8]} = _mm256_add_ps(vacc${M}x${ABC[N:N+8]}, _mm256_mul_ps(va${M}, vb${ABC[N:N+8]}));
93
94 k -= sizeof(float);
95 } while (k != 0);
96
Marat Dukhan104ae5e2021-05-24 13:41:57 -070097 const __m256 vmin = _mm256_load_ps(params->avx.min);
Marat Dukhanfda12b82019-11-21 12:27:59 -080098 $for N in range(0, NR, 8):
99 $for M in range(MR):
100 vacc${M}x${ABC[N:N+8]} = _mm256_max_ps(vacc${M}x${ABC[N:N+8]}, vmin);
101
Marat Dukhan104ae5e2021-05-24 13:41:57 -0700102 const __m256 vmax = _mm256_load_ps(params->avx.max);
103 $for N in range(0, NR, 8):
104 $for M in range(MR):
105 vacc${M}x${ABC[N:N+8]} = _mm256_min_ps(vacc${M}x${ABC[N:N+8]}, vmax);
106
Marat Dukhanfda12b82019-11-21 12:27:59 -0800107 if XNN_LIKELY(nc >= ${NR}) {
108 $for M in reversed(range(MR)):
109 _mm256_storeu_ps(c${M}, vacc${M}x${ABC[0:8]});
110 $for N in range(8, NR, 8):
111 _mm256_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+8]});
112 c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
113
114 $for M in reversed(range(MR)):
115 a${M} = (const float*) ((uintptr_t) a${M} - kc);
116
117 nc -= ${NR};
118 } else {
119 $for LOG2N in reversed(range(NR.bit_length())):
120 $if NR != 1 << LOG2N:
121 if (nc & ${1 << LOG2N}) {
122 $if LOG2N >= 3:
123 $for M in reversed(range(MR)):
124 _mm256_storeu_ps(c${M}, vacc${M}x${ABC[0:8]});
125 $for N in range(8, 1 << LOG2N, 8):
126 _mm256_storeu_ps(c${M} + ${N}, vacc${M}x${ABC[N:N+8]});
127
128 $for M in reversed(range(MR)):
129 $for N in range(0, 1 << (LOG2N - 1), 8):
130 vacc${M}x${ABC[N:N+8]} = vacc${M}x${ABC[N + (1 << LOG2N):N + (1 << LOG2N)+8]};
131
132 $for M in reversed(range(MR)):
133 c${M} += ${1 << LOG2N};
134 $elif LOG2N == 2:
135 $for M in reversed(range(MR)):
136 _mm_storeu_ps(c${M}, vacc${M}x${ABC[0:4]});
137
138 $for M in reversed(range(MR)):
139 vacc${M}x${ABC[0:4]} = _mm256_extractf128_ps(vacc${M}x${ABC[0:8]}, 1);
140
141 $for M in reversed(range(MR)):
142 c${M} += 4;
143 $elif LOG2N == 1:
144 $for M in reversed(range(MR)):
145 _mm_storel_pi((__m64*) c${M}, vacc${M}x${ABC[0:4]});
146
147 $for M in reversed(range(MR)):
148 vacc${M}x${ABC[0:4]} = _mm_movehl_ps(vacc${M}x${ABC[0:4]}, vacc${M}x${ABC[0:4]});
149
150 $for M in reversed(range(MR)):
151 c${M} += 2;
152 $elif LOG2N == 0:
153 $for M in reversed(range(MR)):
154 _mm_store_ss(c${M}, vacc${M}x${ABC[0:4]});
155 }
156 $if LOG2N == 3:
157 $for M in reversed(range(MR)):
158 __m128 vacc${M}x${ABC[0:4]} = _mm256_castps256_ps128(vacc${M}x${ABC[0:8]});
159
160 nc = 0;
161 }
162 } while (nc != 0);
163}