// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <xnnpack/assembly.h>
# void xnn_qs8_gemm_minmax_ukernel_1x16c4__aarch64_neondot_ld64(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     const int8_t* restrict a,  x3
#     size_t a_stride,           (x4)
#     const void* restrict w,    x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          (x7)
#     size_t cn_stride,          [sp] -> x12
#     const union xnn_qs8_gemm_params params)  [sp + 8] -> x11
# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# Register usage
# A0      x3  v0
# B       x5  v4  v5  v6  v7  v16 v17 v18 v19
# C0      x6  v28 v29 v30 v31
# unused  v8 v9 v10 v11 v12 v13 v14 v15
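# Sketch of what this microkernel computes (reference form only; the
# packing of w and the params layout are inferred from the loads below,
# not from a spec):
#
#   for (size_t n = 0; n < 16; n++) {         // 1 row of A, 16 columns of B
#     int32_t acc = bias[n];                   // bias: first 64 bytes of w
#     for (size_t k = 0; k < kc; k++) {
#       acc += (int32_t) a[k] * (int32_t) b[n][k];
#     }
#     c[n] = requantize(acc);  // fixed-point scale, round, +zero point, clamp
#   }
#
# params, as consumed at label 2: int32 multiplier, int32 shift (applied
# with SRSHL, so negative values shift right), int16 output zero point,
# int8 output min, int8 output max.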
BEGIN_FUNCTION xnn_qs8_gemm_minmax_ukernel_1x16c4__aarch64_neondot_ld64
0:
# Load initial bias from w into accumulators
LDP q28, q29, [x5], 32
LDP q30, q31, [x5], 32
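# After the bias, w holds int8 weights packed for c4: each 64-byte block
# carries one 4-byte k-group for all 16 channels (inferred from the SDOT
# indexing below), so one main-loop iteration consumes two such blocks.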
AND x8, x2, 7 // remainder is 0 to 7 bytes.
LDR x11, [sp, 8] // params
# Is there at least 8 bytes?
SUBS x0, x2, 8 // k = kc - 8
B.LO 3f
# Main loop - 8 bytes of A
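# Each SDOT Vd.4S, Vn.16B, Vm.4B[i] accumulates four 4-element int8 dot
# products: the i-th 4-byte group of A against one 4-byte weight group per
# output channel. The eight SDOTs below thus combine 8 bytes of A with
# 128 bytes of packed B across all 16 channels.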
.p2align 3
1:
LDR d0, [x3], 8
LDR q16, [x5, 0]
SUBS x0, x0, 8
LDR q17, [x5, 16]
SDOT v28.4s, v16.16b, v0.4b[0]
LDR q18, [x5, 32]
SDOT v29.4s, v17.16b, v0.4b[0]
LDR q19, [x5, 48]
SDOT v30.4s, v18.16b, v0.4b[0]
LDR q4, [x5, 64]
SDOT v31.4s, v19.16b, v0.4b[0]
LDR q5, [x5, 80]
SDOT v28.4s, v4.16b, v0.4b[1]
LDR q6, [x5, 96]
SDOT v29.4s, v5.16b, v0.4b[1]
LDR q7, [x5, 112]
SDOT v30.4s, v6.16b, v0.4b[1]
ADD x5, x5, 128
SDOT v31.4s, v7.16b, v0.4b[1]
B.HS 1b // loop while 8 or more bytes of A remain
# Is there a remainder? - 1 to 7 bytes of A
CBNZ x8, 3f
.p2align 3
2:
# Apply params - scale, shift, bias and clamp
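# This appears to be gemmlowp-style Q31 requantization: SQRDMULH by a
# fixed-point multiplier, then a rounding right shift via SRSHL with a
# negative shift amount. The CMEQ/BIC/SSRA trio subtracts 1 from the
# product when the accumulator is negative (and the shift is nonzero),
# turning SRSHL's round-half-up into round-half-away-from-zero.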
LD1R {v0.4s}, [x11], 4
SQRDMULH v4.4s, v28.4s, v0.4s
LD1R {v1.4s}, [x11], 4
SQRDMULH v5.4s, v29.4s, v0.4s
SQRDMULH v6.4s, v30.4s, v0.4s
CMEQ v2.4s, v1.4s, 0
SQRDMULH v7.4s, v31.4s, v0.4s
BIC v28.16b, v28.16b, v2.16b
BIC v29.16b, v29.16b, v2.16b
BIC v30.16b, v30.16b, v2.16b
BIC v31.16b, v31.16b, v2.16b
SSRA v4.4s, v28.4s, 31 // signed shift right accumulate
SSRA v5.4s, v29.4s, 31
SSRA v6.4s, v30.4s, 31
SSRA v7.4s, v31.4s, 31
SRSHL v4.4s, v4.4s, v1.4s // signed rounding shift left
SRSHL v5.4s, v5.4s, v1.4s
SRSHL v6.4s, v6.4s, v1.4s
SRSHL v7.4s, v7.4s, v1.4s
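# Saturating-narrow the scaled accumulators from 32 to 16 bits, add the
# int16 output zero point, narrow again to 8 bits, then clamp.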
LD1R {v2.8h}, [x11], 2 // output zero point, added below
SQXTN v4.4h, v4.4s
SQXTN v6.4h, v6.4s
SQXTN2 v4.8h, v5.4s
SQXTN2 v6.8h, v7.4s
LD2R {v0.16b, v1.16b}, [x11] // clamp to min/max
SQADD v4.8h, v4.8h, v2.8h
SQADD v6.8h, v6.8h, v2.8h
LDR x12, [sp] // cn_stride
SQXTN v4.8b, v4.8h
SQXTN2 v4.16b, v6.8h
SUBS x1, x1, 16 // nc -= 16
SMAX v4.16b, v4.16b, v0.16b
SMIN v4.16b, v4.16b, v1.16b
B.LO 4f
# Store full 1 x 16
ST1 {v4.16b}, [x6], x12
SUB x3, x3, x2 // a0 -= kc
B.NE 0b // loop if more columns remain (flags from SUBS above)
RET
# Remainder - 1 to 7 bytes of A
.p2align 3
3:
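# Loads a full 8 bytes of A but advances a0 by only the remainder (x8).
# Lanes past the remainder are assumed harmless: packed B is zero-padded
# to whole 4-byte k-groups, so they multiply against zero weights.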
LD1 {v0.8b}, [x3], x8
LDR q16, [x5, 0]
CMP x8, 4
LDR q17, [x5, 16]
SDOT v28.4s, v16.16b, v0.4b[0]
LDR q18, [x5, 32]
SDOT v29.4s, v17.16b, v0.4b[0]
LDR q19, [x5, 48]
SDOT v30.4s, v18.16b, v0.4b[0]
ADD x5, x5, 64
SDOT v31.4s, v19.16b, v0.4b[0]
B.LS 2b // remainder of 4 bytes or fewer is done; requantize
LDR q4, [x5, 0]
LDR q5, [x5, 16]
SDOT v28.4s, v4.16b, v0.4b[1]
LDR q6, [x5, 32]
SDOT v29.4s, v5.16b, v0.4b[1]
LDR q7, [x5, 48]
SDOT v30.4s, v6.16b, v0.4b[1]
ADD x5, x5, 64
SDOT v31.4s, v7.16b, v0.4b[1]
B 2b
# Store odd width
.p2align 3
4:
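# x1 is nc - 16 here (negative), but its low 4 bits still equal nc, so
# testing bits 3..0 selects 8-, 4-, 2- and 1-byte stores; each DUP shifts
# the remaining lanes down after a partial store.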
TBZ x1, 3, 5f
STR d4, [x6], 8
DUP d4, v4.d[1]
5:
TBZ x1, 2, 6f
STR s4, [x6], 4
DUP s4, v4.s[1]
6:
TBZ x1, 1, 7f
ST1 {v4.h}[0], [x6], 2
DUP h4, v4.h[1]
7:
TBZ x1, 0, 8f
ST1 {v4.b}[0], [x6]
8:
RET
END_FUNCTION xnn_qs8_gemm_minmax_ukernel_1x16c4__aarch64_neondot_ld64
#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif