| // Copyright 2021 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| $assert KERNEL_TILE >= 2 |
| $assert REQUANTIZATION == "FP32" |
| $assert DATATYPE in ["QC8", "QS8", "QU8"] |
| #include <assert.h> |
| $if VARIANT == "LRINT": |
| #include <math.h> |
| $elif VARIANT == "MAGIC": |
| |
| #include <fp16.h> |
| |
| #include <xnnpack/dwconv.h> |
| #include <xnnpack/math.h> |
| |
| /* |
| * Code-generation template (lines starting with "$" and dollar-brace substitutions are |
| * template directives, not C) for a scalar quantized dwconv micro-kernel with FP32 |
| * requantization and min/max clamping. |
| * |
| * Template inputs: DATATYPE (QC8, QS8 or QU8), REQUANTIZATION (must be FP32), |
| * CHANNEL_TILE, KERNEL_TILE (at least 2) and VARIANT (LRINT and MAGIC are the two |
| * values handled below). |
| * |
| * Arguments of the generated function, as used by the body: |
| * channels - number of channels computed per output pixel; asserted non-zero. |
| * output_width - number of output pixels to produce; asserted non-zero. |
| * input - array of input row pointers; KERNEL_TILE entries are read per output pixel. |
| * weights - packed weights, re-read from the start for every output pixel. Per group of |
| * CHANNEL_TILE channels the body reads CHANNEL_TILE int32 biases, then |
| * KERNEL_TILE * CHANNEL_TILE kernel bytes (tap-major), then, for QC8 only, |
| * CHANNEL_TILE float scales. |
| * output - destination; one element is written per channel. |
| * input_stride - byte step applied to `input` after each output pixel. |
| * output_increment - byte step applied to `output` after each output pixel, on top of the |
| * `channels` elements already written. |
| * input_offset - byte offset added to every row pointer that differs from `zero`. |
| * zero - row pointer value that is used as-is, without input_offset. |
| * params - requantization parameters (scale, clamp bounds, zero points, magic bias). |
| */ |
| $FUNCTION_SUFFIX = "scalar_" + VARIANT.lower() if VARIANT else "scalar" |
| $PARAMS_STRUCT = ("" if DATATYPE == "QC8" else REQUANTIZATION.lower() + "_") + "scalar" + ("_" + VARIANT.lower() if VARIANT else "") |
| $PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_%s_conv_minmax_params" % DATATYPE.lower() |
| $XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" |
| void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${FUNCTION_SUFFIX}( |
| size_t channels, |
| size_t output_width, |
| const ${XINT8_T}** input, |
| const void* weights, |
| ${XINT8_T}* output, |
| size_t input_stride, |
| size_t output_increment, |
| size_t input_offset, |
| const ${XINT8_T}* zero, |
| const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) |
| { |
| assert(channels != 0); |
| assert(output_width != 0); |
| /* Load requantization constants once; QC8 instead reads a per-channel scale from the weights. */ |
| $if DATATYPE != "QC8": |
| const float vscale = params->${PARAMS_STRUCT}.scale; |
| $if VARIANT == "LRINT": |
| const long voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; |
| const long voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; |
| const int32_t voutput_zero_point = params->${PARAMS_STRUCT}.output_zero_point; |
| $elif VARIANT == "MAGIC": |
| const float voutput_min_less_zero_point = params->${PARAMS_STRUCT}.output_min_less_zero_point; |
| const float voutput_max_less_zero_point = params->${PARAMS_STRUCT}.output_max_less_zero_point; |
| const float vmagic_bias = params->${PARAMS_STRUCT}.magic_bias; |
| const int32_t vmagic_bias_less_output_zero_point = params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point; |
| $if DATATYPE == "QU8": |
| const int32_t vkernel_zero_point = params->${PARAMS_STRUCT}.kernel_zero_point; /* subtracted from every kernel byte (QU8 only) */ |
| do { /* one iteration per output pixel */ |
| $for K in range(KERNEL_TILE): |
| const ${XINT8_T}* i${K} = input[${K}]; |
| assert(i${K} != NULL); |
| if XNN_UNPREDICTABLE(i${K} != zero) { /* rows equal to `zero` are used without the offset */ |
| i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset); |
| } |
| input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride); /* advance to the next pixel's row pointers (byte stride) */ |
| |
| size_t c = channels; |
| const void* w = weights; /* restart at the beginning of the packed weights */ |
| $if CHANNEL_TILE == 1: |
| do { /* one channel per iteration */ |
| int32_t vacc = *((const int32_t*) w); /* accumulator starts from the int32 bias */ |
| /* NOTE(review): the bias is loaded through a plain int32_t pointer, but w advances by |
| * sizeof(int32_t) + KERNEL_TILE bytes (plus one float for QC8) per channel, so unless |
| * KERNEL_TILE is a multiple of 4 this code alone does not keep w 4-byte aligned - |
| * confirm the target tolerates this or that an unaligned load is intended. */ |
| $for K in range(KERNEL_TILE): |
| $if DATATYPE == "QU8": |
| const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++; |
| $else: |
| const int32_t vi${K} = (int32_t) *i${K}++; |
| $if DATATYPE == "QU8": |
| const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}] - vkernel_zero_point; |
| $else: |
| const int32_t vk${K} = ((const ${XINT8_T}*) ((uintptr_t) w + sizeof(int32_t)))[${K}]; |
| vacc += vi${K} * vk${K}; |
| |
| w = (const void*) ((uintptr_t) w + sizeof(int32_t) + ${KERNEL_TILE} * sizeof(${XINT8_T})); /* skip this channel's bias and KERNEL_TILE kernel bytes */ |
| |
| $if DATATYPE == "QC8": |
| $if CHANNEL_TILE % 4 != 0: |
| typedef XNN_UNALIGNED float unaligned_float; /* XNN_UNALIGNED: presumably permits a misaligned float load - confirm macro definition */ |
| const float vscale = *((const unaligned_float*) w); |
| w = (const void*) ((const float*) w + 1); |
| $else: |
| const float vscale = *((const float*) w); /* NOTE(review): looks unreachable here, CHANNEL_TILE is 1 in this branch */ |
| w = (const void*) ((const float*) w + 1); |
| $if VARIANT == "LRINT": |
| const float vfpacc = (float) vacc * vscale; |
| long vrndacc = lrintf(vfpacc); /* round to an integer in the current rounding mode */ |
| vrndacc = XNN_UNPREDICTABLE(vrndacc < voutput_min_less_zero_point) ? voutput_min_less_zero_point : vrndacc; |
| vrndacc = XNN_UNPREDICTABLE(vrndacc > voutput_max_less_zero_point) ? voutput_max_less_zero_point : vrndacc; |
| int32_t vout = (int32_t) vrndacc + voutput_zero_point; /* zero point is added after clamping */ |
| $elif VARIANT == "MAGIC": |
| float vfpacc = (float) vacc * vscale; |
| vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point); |
| vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point); |
| vfpacc += vmagic_bias; /* NOTE(review): presumably leaves the rounded value in the low mantissa bits - confirm against params setup */ |
| int32_t vout = (int32_t) fp32_to_bits(vfpacc) - vmagic_bias_less_output_zero_point; |
| |
| *output++ = (${XINT8_T}) vout; /* narrow to the 8-bit output type */ |
| } while (--c != 0); |
| $else: |
| for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) { /* main loop: CHANNEL_TILE channels per iteration */ |
| $for C in range(CHANNEL_TILE): |
| int32_t vacc${C} = ((const int32_t*) w)[${C}]; |
| |
| $for K in range(KERNEL_TILE): |
| |
| $for C in range(CHANNEL_TILE): |
| $if DATATYPE == "QU8": |
| const int32_t vi${K}x${C} = (int32_t) (uint32_t) i${K}[${C}]; |
| $else: |
| const int32_t vi${K}x${C} = (int32_t) i${K}[${C}]; |
| i${K} += ${CHANNEL_TILE}; |
| |
| $for C in range(CHANNEL_TILE): |
| $if DATATYPE == "QU8": |
| const int32_t vk${K}x${C} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}] - vkernel_zero_point; |
| $else: |
| const int32_t vk${K}x${C} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE + C}]; |
| |
| $for C in range(CHANNEL_TILE): |
| vacc${C} += vi${K}x${C} * vk${K}x${C}; |
| |
| w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T})); /* skip the biases and kernel bytes of this channel group */ |
| |
| $for C in range(CHANNEL_TILE): |
| float vfpacc${C} = (float) vacc${C}; /* requantization is done in float */ |
| |
| $if DATATYPE == "QC8": |
| $if CHANNEL_TILE % 4 != 0: |
| typedef XNN_UNALIGNED float unaligned_float; |
| $for C in range(CHANNEL_TILE): |
| const float vscale${C} = ((const unaligned_float*) w)[${C}]; |
| $else: |
| $for C in range(CHANNEL_TILE): |
| const float vscale${C} = ((const float*) w)[${C}]; |
| w = (const void*) ((const float*) w + ${CHANNEL_TILE}); /* step past the per-channel scales */ |
| |
| $for C in range(CHANNEL_TILE): |
| vfpacc${C} *= vscale${C}; |
| $else: |
| $for C in range(CHANNEL_TILE): |
| vfpacc${C} *= vscale; |
| |
| $if VARIANT == "LRINT": |
| $for C in range(CHANNEL_TILE): |
| long vrndacc${C} = lrintf(vfpacc${C}); |
| |
| $for C in range(CHANNEL_TILE): |
| vrndacc${C} = XNN_UNPREDICTABLE(vrndacc${C} < voutput_min_less_zero_point) ? voutput_min_less_zero_point : vrndacc${C}; |
| |
| $for C in range(CHANNEL_TILE): |
| vrndacc${C} = XNN_UNPREDICTABLE(vrndacc${C} > voutput_max_less_zero_point) ? voutput_max_less_zero_point : vrndacc${C}; |
| |
| $for C in range(CHANNEL_TILE): |
| int32_t vout${C} = (int32_t) vrndacc${C} + voutput_zero_point; |
| $elif VARIANT == "MAGIC": |
| $for C in range(CHANNEL_TILE): |
| vfpacc${C} = math_max_f32(vfpacc${C}, voutput_min_less_zero_point); |
| |
| $for C in range(CHANNEL_TILE): |
| vfpacc${C} = math_min_f32(vfpacc${C}, voutput_max_less_zero_point); |
| |
| $for C in range(CHANNEL_TILE): |
| vfpacc${C} += vmagic_bias; |
| |
| $for C in range(CHANNEL_TILE): |
| int32_t vout${C} = (int32_t) fp32_to_bits(vfpacc${C}) - vmagic_bias_less_output_zero_point; |
| |
| $for C in range(CHANNEL_TILE): |
| output[${C}] = (${XINT8_T}) vout${C}; |
| output += ${CHANNEL_TILE}; |
| } |
| if XNN_UNLIKELY(c != 0) { /* remainder: fewer than CHANNEL_TILE channels are left */ |
| $if CHANNEL_TILE == 2: |
| int32_t vacc = *((const int32_t*) w); /* with CHANNEL_TILE == 2 exactly one channel remains */ |
| |
| $for K in range(KERNEL_TILE): |
| $if DATATYPE == "QU8": |
| const int32_t vi${K} = (int32_t) (uint32_t) *i${K}; |
| $else: |
| const int32_t vi${K} = (int32_t) *i${K}; |
| $if DATATYPE == "QU8": |
| const int32_t vk${K} = (int32_t) (uint32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE}] - vkernel_zero_point; |
| $else: |
| const int32_t vk${K} = (int32_t) ((const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)))[${K * CHANNEL_TILE}]; |
| vacc += vi${K} * vk${K}; |
| |
| $if DATATYPE == "QC8": |
| $if CHANNEL_TILE % 4 != 0: |
| typedef XNN_UNALIGNED float unaligned_float; |
| const float vscale = *((const unaligned_float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); /* offsets are computed as for a full CHANNEL_TILE group */ |
| $else: |
| const float vscale = *((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); /* NOTE(review): looks unreachable here, CHANNEL_TILE is 2 in this branch */ |
| $if VARIANT == "LRINT": |
| const float vfpacc = (float) vacc * vscale; |
| long vrndacc = lrintf(vfpacc); |
| vrndacc = XNN_UNPREDICTABLE(vrndacc < voutput_min_less_zero_point) ? voutput_min_less_zero_point : vrndacc; |
| vrndacc = XNN_UNPREDICTABLE(vrndacc > voutput_max_less_zero_point) ? voutput_max_less_zero_point : vrndacc; |
| int32_t vout = (int32_t) vrndacc + voutput_zero_point; |
| $elif VARIANT == "MAGIC": |
| float vfpacc = (float) vacc * vscale; |
| vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point); |
| vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point); |
| vfpacc += vmagic_bias; |
| int32_t vout = (int32_t) fp32_to_bits(vfpacc) - vmagic_bias_less_output_zero_point; |
| |
| *output++ = (${XINT8_T}) vout; |
| $else: |
| const ${XINT8_T}* k = (const ${XINT8_T}*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t)); /* first kernel byte of the partial channel group */ |
| do { /* one leftover channel per iteration */ |
| int32_t vacc = *((const int32_t*) w); |
| w = (const void*) ((uintptr_t) w + sizeof(int32_t)); /* w now steps one bias at a time */ |
| |
| $for K in range(KERNEL_TILE): |
| $if DATATYPE == "QU8": |
| const int32_t vi${K} = (int32_t) (uint32_t) *i${K}++; |
| $else: |
| const int32_t vi${K} = (int32_t) *i${K}++; |
| $if DATATYPE == "QU8": |
| const int32_t vk${K} = (int32_t) (uint32_t) k[${K * CHANNEL_TILE}] - vkernel_zero_point; |
| $else: |
| const int32_t vk${K} = (int32_t) k[${K * CHANNEL_TILE}]; |
| vacc += vi${K} * vk${K}; |
| k += 1; /* next channel; taps of one channel are CHANNEL_TILE bytes apart */ |
| /* QC8: w was already advanced by one int32 above, hence CHANNEL_TILE - 1 in the scale |
| * offset below; each later bias step lands on the next channel's scale only if |
| * sizeof(float) == sizeof(int32_t). */ |
| $if DATATYPE == "QC8": |
| $if CHANNEL_TILE % 4 != 0: |
| typedef XNN_UNALIGNED float unaligned_float; |
| const float vscale = *((const unaligned_float*) ((uintptr_t) w + ${CHANNEL_TILE - 1} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); |
| $else: |
| const float vscale = *((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 1} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(${XINT8_T}))); |
| $if VARIANT == "LRINT": |
| const float vfpacc = (float) vacc * vscale; |
| long vrndacc = lrintf(vfpacc); |
| vrndacc = XNN_UNPREDICTABLE(vrndacc < voutput_min_less_zero_point) ? voutput_min_less_zero_point : vrndacc; |
| vrndacc = XNN_UNPREDICTABLE(vrndacc > voutput_max_less_zero_point) ? voutput_max_less_zero_point : vrndacc; |
| int32_t vout = (int32_t) vrndacc + voutput_zero_point; |
| $elif VARIANT == "MAGIC": |
| float vfpacc = (float) vacc * vscale; |
| vfpacc = math_max_f32(vfpacc, voutput_min_less_zero_point); |
| vfpacc = math_min_f32(vfpacc, voutput_max_less_zero_point); |
| vfpacc += vmagic_bias; |
| int32_t vout = (int32_t) fp32_to_bits(vfpacc) - vmagic_bias_less_output_zero_point; |
| |
| *output++ = (${XINT8_T}) vout; |
| } while (--c != 0); |
| } |
| |
| output = (${XINT8_T}*) ((uintptr_t) output + output_increment); /* byte step to the next output pixel */ |
| } while (--output_width != 0); |
| } |