Marat Dukhan | 80fc932 | 2019-09-29 21:06:36 -0700 | [diff] [blame] | 1 | // Copyright 2019 Google LLC |
| 2 | // |
| 3 | // This source code is licensed under the BSD-style license found in the |
| 4 | // LICENSE file in the root directory of this source tree. |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 5 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 6 | $assert CHANNEL_TILE % 4 == 0 |
| 7 | $assert CHANNEL_TILE >= 4 |
| 8 | $assert ROW_TILE >= 1 |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 9 | $ABC = "0123456789ABCDEFGHIJKLMN" |
| 10 | #include <assert.h> |
| 11 | |
| 12 | #include <arm_neon.h> |
| 13 | |
| 14 | #include <xnnpack/math.h> |
| 15 | #include <xnnpack/vmulcaddc.h> |
| 16 | |
| 17 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 18 | void xnn_f32_vmulcaddc_ukernel_c${CHANNEL_TILE}__${"neonfma" if FMA else "neon"}_${ROW_TILE}x( |
| 19 | size_t rows, |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 20 | size_t channels, |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 21 | const float*restrict input, |
| 22 | size_t input_stride, |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 23 | const float*restrict weights, |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 24 | float*restrict output, |
| 25 | size_t output_stride, |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 26 | const union xnn_f32_output_params params[restrict static 1]) |
| 27 | { |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 28 | assert(rows != 0); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 29 | assert(channels != 0); |
| 30 | assert(channels % sizeof(float) == 0); |
| 31 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 32 | const float* i0 = input; |
| 33 | float* o0 = output; |
| 34 | $for M in range(1, ROW_TILE): |
| 35 | const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride); |
| 36 | float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 37 | $if M % 2 == 0: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 38 | if XNN_UNPREDICTABLE(rows <= ${M}) { |
| 39 | i${M} = i${M-1}; |
| 40 | o${M} = o${M-1}; |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 41 | } |
| 42 | $else: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 43 | if XNN_UNPREDICTABLE(rows < ${M+1}) { |
| 44 | i${M} = i${M-1}; |
| 45 | o${M} = o${M-1}; |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 46 | } |
| 47 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 48 | const size_t input_increment = input_stride * ${ROW_TILE} - channels; |
| 49 | const size_t output_increment = output_stride * ${ROW_TILE} - channels; |
| 50 | |
Frank Barchard | fcfdc0e | 2019-10-21 15:58:42 -0700 | [diff] [blame] | 51 | const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min); |
| 52 | const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 53 | do { |
| 54 | const float* w = weights; |
| 55 | size_t c = channels; |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 56 | for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) { |
| 57 | $for C in range(0, CHANNEL_TILE, 4): |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 58 | const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32(w); w += 4; |
| 59 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 60 | $for M in range(ROW_TILE): |
| 61 | $for C in range(0, CHANNEL_TILE, 4): |
| 62 | float32x4_t vacc${M}x${ABC[C:C+4]} = vld1q_f32(i${M}); i${M} += 4; |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 63 | |
| 64 | $if not FMA: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 65 | $for M in range(ROW_TILE): |
| 66 | $for C in range(0, CHANNEL_TILE, 4): |
| 67 | vacc${M}x${ABC[C:C+4]} = vmulq_f32(vacc${M}x${ABC[C:C+4]}, vscale${ABC[C:C+4]}); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 68 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 69 | $for C in range(0, CHANNEL_TILE, 4): |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 70 | const float32x4_t vbias${ABC[C:C+4]} = vld1q_f32(w); w += 4; |
| 71 | |
| 72 | $if not FMA: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 73 | $for M in range(ROW_TILE): |
| 74 | $for C in range(0, CHANNEL_TILE, 4): |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 75 | vacc${M}x${ABC[C:C+4]} = vaddq_f32(vacc${M}x${ABC[C:C+4]}, vbias${ABC[C:C+4]}); |
| 76 | $else: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 77 | $for M in range(ROW_TILE): |
| 78 | $for C in range(0, CHANNEL_TILE, 4): |
| 79 | vacc${M}x${ABC[C:C+4]} = vfmaq_f32(vbias${ABC[C:C+4]}, vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]}); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 80 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 81 | $for M in range(ROW_TILE): |
| 82 | $for C in range(0, CHANNEL_TILE, 4): |
Frank Barchard | fcfdc0e | 2019-10-21 15:58:42 -0700 | [diff] [blame] | 83 | vacc${M}x${ABC[C:C+4]} = vmaxq_f32(vacc${M}x${ABC[C:C+4]}, vmin); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 84 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 85 | $for M in range(ROW_TILE): |
| 86 | $for C in range(0, CHANNEL_TILE, 4): |
Frank Barchard | fcfdc0e | 2019-10-21 15:58:42 -0700 | [diff] [blame] | 87 | vacc${M}x${ABC[C:C+4]} = vminq_f32(vacc${M}x${ABC[C:C+4]}, vmax); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 88 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 89 | $for M in range(ROW_TILE): |
| 90 | $for C in range(0, CHANNEL_TILE, 4): |
| 91 | vst1q_f32(o${M}, vacc${M}x${ABC[C:C+4]}); o${M} += 4; |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 92 | } |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 93 | $if CHANNEL_TILE > 4: |
| 94 | for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) { |
| 95 | const float32x4_t vscale0123 = vld1q_f32(w); w += 4; |
| 96 | |
| 97 | $for M in range(ROW_TILE): |
| 98 | float32x4_t vacc${M}x0123 = vld1q_f32(i${M}); i${M} += 4; |
| 99 | |
| 100 | $if not FMA: |
| 101 | $for M in range(ROW_TILE): |
| 102 | vacc${M}x0123 = vmulq_f32(vacc${M}x0123, vscale0123); |
| 103 | |
| 104 | const float32x4_t vbias0123 = vld1q_f32(w + ${CHANNEL_TILE - 4}); |
| 105 | |
| 106 | $if not FMA: |
| 107 | $for M in range(ROW_TILE): |
| 108 | vacc${M}x0123 = vaddq_f32(vacc${M}x0123, vbias0123); |
| 109 | $else: |
| 110 | $for M in range(ROW_TILE): |
| 111 | vacc${M}x0123 = vfmaq_f32(vbias0123, vscale0123, vacc${M}x0123); |
| 112 | |
| 113 | $for M in range(ROW_TILE): |
| 114 | vacc${M}x0123 = vmaxq_f32(vacc${M}x0123, vmin); |
| 115 | |
| 116 | $for M in range(ROW_TILE): |
| 117 | vacc${M}x0123 = vminq_f32(vacc${M}x0123, vmax); |
| 118 | |
| 119 | $for M in range(ROW_TILE): |
| 120 | vst1q_f32(o${M}, vacc${M}x0123); o${M} += 4; |
| 121 | } |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 122 | if XNN_UNLIKELY(c != 0) { |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 123 | const float32x4_t vscale0123 = vld1q_f32(w); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 124 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 125 | $for M in range(ROW_TILE): |
| 126 | float32x4_t vacc${M}x0123 = vld1q_f32(i${M}); i${M} = (const float*) ((uintptr_t) i${M} + c); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 127 | |
| 128 | $if not FMA: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 129 | $for M in range(ROW_TILE): |
| 130 | vacc${M}x0123 = vmulq_f32(vacc${M}x0123, vscale0123); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 131 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 132 | const float32x4_t vbias0123 = vld1q_f32(w + ${CHANNEL_TILE}); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 133 | |
| 134 | $if not FMA: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 135 | $for M in range(ROW_TILE): |
| 136 | vacc${M}x0123 = vaddq_f32(vacc${M}x0123, vbias0123); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 137 | $else: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 138 | $for M in range(ROW_TILE): |
| 139 | vacc${M}x0123 = vfmaq_f32(vbias0123, vscale0123, vacc${M}x0123); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 140 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 141 | $for M in range(ROW_TILE): |
| 142 | vacc${M}x0123 = vmaxq_f32(vacc${M}x0123, vmin); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 143 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 144 | $for M in range(ROW_TILE): |
| 145 | vacc${M}x0123 = vminq_f32(vacc${M}x0123, vmax); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 146 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 147 | $for M in range(ROW_TILE): |
| 148 | float32x2_t vacc${M}x01 = vget_low_f32(vacc${M}x0123); |
| 149 | if (c & (2 * sizeof(float))) { |
| 150 | $for M in range(ROW_TILE): |
| 151 | vst1_f32(o${M}, vacc${M}x01); o${M} += 2; |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 152 | |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 153 | $for M in range(ROW_TILE): |
| 154 | vacc${M}x01 = vget_high_f32(vacc${M}x0123); |
| 155 | } |
| 156 | if (c & (1 * sizeof(float))) { |
| 157 | $for M in range(ROW_TILE): |
| 158 | vst1_lane_f32(o${M}, vacc${M}x01, 0); o${M} += 1; |
| 159 | } |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 160 | } |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 161 | $for M in range(ROW_TILE): |
| 162 | i${M} = (const float*) ((uintptr_t) i${M} + input_increment); |
| 163 | o${M} = (float*) ((uintptr_t) o${M} + output_increment); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 164 | $if M % 2 == 1: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 165 | if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M + 1}) { |
| 166 | i${M} = i${M-1}; |
| 167 | o${M} = o${M-1}; |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 168 | } |
| 169 | $elif M != 0: |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 170 | if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) { |
| 171 | i${M} = i${M-1}; |
| 172 | o${M} = o${M-1}; |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 173 | } |
Marat Dukhan | 49e6ee9 | 2019-11-06 15:55:29 -0800 | [diff] [blame] | 174 | rows = doz(rows, ${ROW_TILE}); |
| 175 | } while (rows != 0); |
XNNPACK Team | b455b12 | 2019-09-27 18:10:33 -0700 | [diff] [blame] | 176 | } |