// blob: a34d6fddd157cb9ee6e2cc50fbccfaf996979a27 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <stddef.h>
#include <stdint.h>
#include <pthreadpool.h>
#include <xnnpack/requantization.h>
#include <xnnpack/compute.h>
// Tag identifying a family of micro-kernels.
// Enumerators are numbered implicitly starting from 0; their order determines
// their values, so new entries must not be inserted between existing ones
// without auditing every user of the numeric values.
enum xnn_ukernel_type {
  // Zero value: no micro-kernel type.
  xnn_ukernel_type_none = 0,

  xnn_ukernel_type_add,
  xnn_ukernel_type_argmax_pooling,
  xnn_ukernel_type_average_pooling,
  xnn_ukernel_type_channel_shuffle,
  xnn_ukernel_type_clamp,
  // NOTE(review): igemm is listed here, out of alphabetical order; kept as-is
  // because moving it would change enumerator values.
  xnn_ukernel_type_igemm,
  xnn_ukernel_type_dconv2d_hwc2spchw,
  xnn_ukernel_type_dwconv,
  xnn_ukernel_type_gemm,
  xnn_ukernel_type_global_average_pooling,
  xnn_ukernel_type_hswish,
  xnn_ukernel_type_lut,
  xnn_ukernel_type_max_pooling,
  xnn_ukernel_type_pad,
  xnn_ukernel_type_pixelwise_average_pooling,
  xnn_ukernel_type_prelu,
  xnn_ukernel_type_softargmax,
  xnn_ukernel_type_spmm,
  xnn_ukernel_type_subconv2d,
  xnn_ukernel_type_unpooling,
  xnn_ukernel_type_vmulcaddc,
};
// Tag identifying the kind of an operator.
// Each enumerator name ends in a suffix (f32, q8, u8, x8, x32) that
// presumably names the element type the operator works on — confirm against
// the operator implementations.
// Enumerators are numbered implicitly starting from 0, in declaration order.
enum xnn_operator_type {
  // Zero value: no operator type.
  xnn_operator_type_none = 0,

  // Addition.
  xnn_operator_type_add_f32,
  xnn_operator_type_add_q8,

  // Pooling (argmax / average).
  xnn_operator_type_argmax_pooling_f32,
  xnn_operator_type_average_pooling_f32,
  xnn_operator_type_average_pooling_q8,

  // Channel manipulation.
  xnn_operator_type_channel_pad_x32,
  xnn_operator_type_channel_shuffle_x8,
  xnn_operator_type_channel_shuffle_x32,

  // Clamp.
  xnn_operator_type_clamp_f32,
  xnn_operator_type_clamp_u8,

  // Convolution / deconvolution / fully connected.
  xnn_operator_type_convolution_f32,
  xnn_operator_type_convolution_spnchw_f32,
  xnn_operator_type_convolution_q8,
  xnn_operator_type_deconvolution_f32,
  xnn_operator_type_deconvolution_q8,
  xnn_operator_type_fully_connected_f32,
  xnn_operator_type_fully_connected_q8,

  // Global average pooling.
  xnn_operator_type_global_average_pooling_f32,
  xnn_operator_type_global_average_pooling_q8,
  xnn_operator_type_global_average_pooling_spnchw_f32,

  // Remaining operator kinds.
  xnn_operator_type_hswish_f32,
  xnn_operator_type_leaky_relu_q8,
  xnn_operator_type_max_pooling_f32,
  xnn_operator_type_max_pooling_u8,
  xnn_operator_type_prelu_f32,
  xnn_operator_type_sigmoid_q8,
  xnn_operator_type_softargmax_q8,
  xnn_operator_type_unpooling_x32,
};
// Micro-kernel descriptor for the "dconv2d" kernel family.
struct xnn_ukernel_dconv2d {
  // Exactly one of these function pointers is stored at a time (they share
  // storage). Which one is valid is not recorded in this struct.
  union {
    xnn_conv_hwc2spchw_ukernel_function hwc2spchw_function;
    xnn_conv_hwc_ukernel_function hwc_function;
  };

  // Presumably the output tile handled per micro-kernel call along the
  // height and channel dimensions — TODO confirm against the kernels.
  uint8_t output_height_tile;
  uint8_t output_channel_tile;
};
// Micro-kernel descriptor for the "dwconv" kernel family.
struct xnn_ukernel_dwconv {
  // The unipass and multipass entry points share storage; only one of them
  // is held at a time. The selector is not part of this struct.
  union {
    xnn_dwconv_up_ukernel_function unipass_function;
    xnn_dwconv_mp_ukernel_function multipass_function;
  };

  // Kernel tiling parameters; exact meaning of mr/qr is defined by the
  // micro-kernels — TODO confirm.
  uint8_t mr;
  uint8_t qr;
};
// Direct 2D Depthwise Convolution
//
// Micro-kernel descriptor: a function pointer plus two width tile sizes.
struct xnn_ukernel_dwconv2d {
  // Single-member anonymous union holding the kernel entry point.
  union {
    xnn_dwconv_spchw_ukernel_function spchw_function;
  };

  // Presumably the number of input / output columns processed per kernel
  // call — TODO confirm against the kernels.
  uint8_t input_width_tile;
  uint8_t output_width_tile;
};
// Micro-kernel descriptor for the "gemm" kernel family.
struct xnn_ukernel_gemm {
  // Two entry points of the same function type.
  // NOTE(review): mr1_function looks like a variant for mr == 1 — confirm.
  xnn_gemm_ukernel_function default_function;
  xnn_gemm_ukernel_function mr1_function;

  // Kernel tiling parameters; presumably the tile sizes of default_function
  // — TODO confirm.
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
};
// Micro-kernel descriptor for the "igemm" kernel family. Field-for-field it
// mirrors struct xnn_ukernel_gemm, with the igemm function pointer type.
struct xnn_ukernel_igemm {
  // Two entry points of the same function type.
  // NOTE(review): mr1_function looks like a variant for mr == 1 — confirm.
  xnn_igemm_ukernel_function default_function;
  xnn_igemm_ukernel_function mr1_function;

  // Kernel tiling parameters; presumably the tile sizes of default_function
  // — TODO confirm.
  uint8_t mr;
  uint8_t nr;
  uint8_t kr;
};
// Micro-kernel descriptor for the "spmm" kernel family: one entry point and
// one tiling parameter.
struct xnn_ukernel_spmm {
  xnn_spmm_ukernel_function function;
  // Tiling parameter of `function`; exact meaning defined by the kernel.
  uint8_t mr;
};
// Micro-kernel descriptor for the "vmulcaddc" kernel family: one entry point
// and one tiling parameter.
struct xnn_ukernel_vmulcaddc {
  xnn_vmulcaddc_ukernel_function function;
  // Tiling parameter of `function`; exact meaning defined by the kernel.
  uint8_t mr;
};
// Tagged container for a micro-kernel descriptor: a type tag followed by an
// anonymous union of the per-family descriptor structs.
struct xnn_ukernel {
  // Presumably indicates which union member below is meaningful — confirm
  // against the code that fills this struct in.
  enum xnn_ukernel_type type;

  // All descriptors share storage; at most one is held at a time.
  union {
    struct xnn_ukernel_dconv2d dconv2d;
    struct xnn_ukernel_dwconv dwconv;
    struct xnn_ukernel_dwconv2d dwconv2d;
    struct xnn_ukernel_gemm gemm;
    struct xnn_ukernel_igemm igemm;
    struct xnn_ukernel_spmm spmm;
    struct xnn_ukernel_vmulcaddc vmulcaddc;
  };
};
// Run state recorded in an operator (see the `state` field of
// struct xnn_operator). The zero value is "invalid", so a zero-initialized
// operator starts out in the invalid state.
enum xnn_run_state {
  xnn_run_state_invalid = 0,
  xnn_run_state_ready,
  xnn_run_state_skip,
};
// Parameter record for one sub-convolution. struct xnn_operator holds a
// pointer to these in its `subconvolution_buffer` field.
struct subconvolution_params {
  // Weights pointer and its stride.
  void* weights;
  size_t w_stride;

  // Buffer of input pointers, and the output pointer.
  const void** indirection_buffer;
  void* output;

  // Slice extents.
  size_t slice_width;
  size_t slice_height;

  // Strides through the indirection buffer in y and x.
  size_t indirection_y_stride;
  size_t indirection_x_stride;

  // kernel_size * mr * sizeof(void*)
  size_t scaled_kernel_size;
};
// Aggregate describing one operator: geometry and channel counts, input and
// output pointers with their strides, sparse-weight statistics, quantization
// values, cached values from a previous setup, auxiliary buffers, kernel
// parameters, the selected micro-kernel, compute descriptors and contexts,
// and the run state.
//
// Field order is part of the struct layout; do not reorder.
struct xnn_operator {
  size_t batch_size;

  // Padding, output adjustment, kernel size, stride and dilation.
  uint32_t padding_top;
  uint32_t padding_right;
  uint32_t padding_bottom;
  uint32_t padding_left;
  uint32_t adjustment_height;
  uint32_t adjustment_width;
  uint32_t kernel_height;
  uint32_t kernel_width;
  uint32_t stride_height;
  uint32_t stride_width;
  uint32_t dilation_height;
  uint32_t dilation_width;

  // Group and channel counts.
  uint32_t groups;
  size_t group_channels;
  size_t group_input_channels;
  size_t group_output_channels;
  size_t channels;

  // Channel padding amounts and the 32-bit fill value.
  size_t pad_before_channels;
  size_t pad_after_channels;
  uint32_t pad_value;

  // First input: extents, pixel stride, data pointer and related buffers.
  size_t input_height;
  size_t input_width;
  size_t input_pixel_stride;
  const void* input;
  const void** indirection_buffer;
  void* a_sum;

  // Second input: pixel stride and data pointer.
  size_t input2_pixel_stride;
  const void* input2;

  // Output: extents, pixel stride and data pointer.
  size_t output_height;
  size_t output_width;
  size_t output_pixel_stride;
  void* output;

  void* packed_weights;
  // Total number of non-zero kernel elements when weights use sparse representation.
  size_t num_nonzero_values;
  // Total number of non-zero kernel blocks when weights use sparse representation.
  size_t num_nonzero_blocks;
  // Total number of output channel blocks when weights use sparse representation.
  size_t num_output_channel_blocks;
  // Input channel corresponding to the first non-zero kernel element.
  size_t first_input_channel;

  // Scales, zero points and output range limits.
  float input_scale;
  float output_scale;
  uint8_t input_zero_point;
  uint8_t kernel_zero_point;
  uint8_t output_zero_point;
  uint8_t output_min;
  uint8_t output_max;

  // NOTE(review): these look like values remembered from the previous setup
  // so that unchanged work can be reused — confirm against the setup code.
  size_t valid_batch_size;
  size_t last_input_height;
  size_t last_input_width;
  const void* last_input;
  void* last_output;

  // Auxiliary buffers.
  void* zero_buffer;
  void* lookup_table;
  void* pixelwise_buffer;
  struct subconvolution_params* subconvolution_buffer;

  // Kernel parameter blocks; all members share storage, so at most one is
  // held at a time.
  union {
    union xnn_f32_avgpool_params f32_avgpool_params;
    union xnn_f32_gavgpool_params f32_gavgpool_params;
    union xnn_f32_hswish_params f32_hswish_params;
    union xnn_f32_output_params f32_output_params;
    union xnn_f32_spchw_params f32_spchw_params;
    union xnn_q8_add_params q8_add_params;
    union xnn_q8_avgpool_params q8_avgpool_params;
    union xnn_q8_gemm_params q8_gemm_params;
    union xnn_u8_output_params u8_output_params;
  };

  // Operator kind and the micro-kernel descriptor.
  enum xnn_operator_type type;
  struct xnn_ukernel ukernel;

  // Two compute descriptors.
  struct compute_parameters compute;
  struct compute_parameters compute2;

  // Per-operator context; all members share storage, so at most one is held
  // at a time.
  union {
    struct add_contiguous_context add_contiguous;
    struct add_strided_context add_strided;
    struct argmax_pooling_context argmax_pooling;
    struct average_pooling_context average_pooling;
    struct channel_pad_context channel_pad;
    struct channel_shuffle_context channel_shuffle;
    struct dconv2d_context dconv2d;
    struct dwconv2d_context dwconv2d;
    struct dwconv_context dwconv;
    struct gemm_context gemm;
    struct global_average_pooling_context global_average_pooling;
    struct global_average_pooling_spnchw_context global_average_pooling_spnchw;
    struct igemm_context igemm;
    struct lut_contiguous_context lut_contiguous;
    struct lut_strided_context lut_strided;
    struct max_pooling_context max_pooling;
    struct pixelwise_average_pooling_context pixelwise_average_pooling;
    struct prelu_context prelu;
    struct spmm_context spmm;
    struct subconv_context subconv;
    struct u8_softargmax_context u8_softargmax;
    struct univector_contiguous_context univector_contiguous;
    struct univector_strided_context univector_strided;
    struct unpooling_context unpooling;
    struct vmulcaddc_context vmulcaddc;
  } context;

  // Run state; a zero-initialized operator has xnn_run_state_invalid here.
  enum xnn_run_state state;
};