QS8 Sigmoid operator

PiperOrigin-RevId: 395713278
diff --git a/include/xnnpack.h b/include/xnnpack.h
index 06cd2ff..1e8a891 100644
--- a/include/xnnpack.h
+++ b/include/xnnpack.h
@@ -2185,6 +2185,26 @@
   int8_t* output,
   pthreadpool_t threadpool);
 
+enum xnn_status xnn_create_sigmoid_nc_qs8(  // Creates a Sigmoid operator for signed 8-bit quantized (QS8) NC data.
+  size_t channels,
+  size_t input_stride,
+  size_t output_stride,
+  int8_t input_zero_point,
+  float input_scale,
+  int8_t output_zero_point,  // The implementation accepts only -128.
+  float output_scale,  // The implementation accepts only 1/256.
+  int8_t output_min,
+  int8_t output_max,
+  uint32_t flags,
+  xnn_operator_t* sigmoid_op_out);  // Receives the created operator.
+
+enum xnn_status xnn_setup_sigmoid_nc_qs8(  // Binds a batch size and int8 input/output buffers to a QS8 Sigmoid operator.
+  xnn_operator_t sigmoid_op,
+  size_t batch_size,
+  const int8_t* input,
+  int8_t* output,
+  pthreadpool_t threadpool);  // Not forwarded by the implementation in lut-elementwise-nc.c.
+
 enum xnn_status xnn_create_subtract_nd_qs8(
   int8_t input1_zero_point,
   float input1_scale,
diff --git a/src/operator-strings.c b/src/operator-strings.c
index 83ce67a..5c7e8c9 100644
--- a/src/operator-strings.c
+++ b/src/operator-strings.c
@@ -134,6 +134,8 @@
       return "Resize Bilinear (NCHW, F32)";
     case xnn_operator_type_sigmoid_nc_f32:
       return "Sigmoid (NC, F32)";
+    case xnn_operator_type_sigmoid_nc_qs8:
+      return "Sigmoid (NC, QS8)";
     case xnn_operator_type_sigmoid_nc_qu8:
       return "Sigmoid (NC, QU8)";
     case xnn_operator_type_softmax_nc_f32:
diff --git a/src/operators/lut-elementwise-nc.c b/src/operators/lut-elementwise-nc.c
index 4a25518..b49daa7 100644
--- a/src/operators/lut-elementwise-nc.c
+++ b/src/operators/lut-elementwise-nc.c
@@ -23,6 +23,7 @@
     size_t output_stride,
     int32_t input_zero_point,
     float input_scale,
+    int32_t input_min,
     long output_zero_point,
     float output_scale,
     long output_min,
@@ -108,13 +109,13 @@
 
   uint8_t* lookup_table = lut_elementwise_op->lookup_table;
   const float inv_output_scale = 1.0f / output_scale;
-  for (int32_t i = 0; i < 256; i++) {
+  for (int32_t i = input_min; i < input_min + 256; i++) {
     const float dequantized_input = (i - input_zero_point) * input_scale;
     const float dequantized_output = init_fn(dequantized_input, init_params);
     long quantized_output = lrintf(dequantized_output * inv_output_scale) + output_zero_point;
     quantized_output = XNN_UNPREDICTABLE(quantized_output < output_min) ? output_min : quantized_output;
     quantized_output = XNN_UNPREDICTABLE(quantized_output > output_max) ? output_max : quantized_output;
-    lookup_table[i] = (uint8_t) quantized_output;
+    lookup_table[(uint8_t) i] = (uint8_t) quantized_output;
   }
 
   lut_elementwise_op->channels = channels;
@@ -178,7 +179,7 @@
 
   return create_lut_elementwise_nc(
     channels, input_stride, output_stride,
-    (int32_t) (uint32_t) input_zero_point, input_scale,
+    (int32_t) (uint32_t) input_zero_point, input_scale, 0 /* input min */,
     (long) (unsigned long) output_zero_point, output_scale,
     (long) (unsigned long) output_min, (long) (unsigned long) output_max,
     flags,
@@ -190,6 +191,43 @@
   return signbit(x) ? 1.0f / (1.0f + expf(-x)) : 1.0f - 1.0f / (1.0f + expf(x));
 }
 
+enum xnn_status xnn_create_sigmoid_nc_qs8(  // QS8 Sigmoid: validates output quantization, then builds the lookup-table operator.
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    int8_t input_zero_point,
+    float input_scale,
+    int8_t output_zero_point,  // Only -128 is accepted (checked below).
+    float output_scale,  // Only 1/256 is accepted (checked below).
+    int8_t output_min,
+    int8_t output_max,
+    uint32_t flags,
+    xnn_operator_t* sigmoid_op_out)
+{
+  if (output_scale != 0x1.0p-8f) {  // 0x1.0p-8f == 1/256
+    xnn_log_error(
+      "failed to create %s operator with %.7g output scale: only output scale of 1/256 is supported",
+      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_scale);
+    return xnn_status_unsupported_parameter;
+  }
+
+  if (output_zero_point != -128) {
+    xnn_log_error(  // The zero point is a signed int8_t, so it is formatted with PRId8 rather than PRIu8.
+      "failed to create %s operator with %" PRId8 " output zero point: only output zero point of -128 is supported",
+      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_zero_point);
+    return xnn_status_unsupported_parameter;
+  }
+
+  return create_lut_elementwise_nc(
+    channels, input_stride, output_stride,
+    (int32_t) input_zero_point, input_scale, INT8_MIN,  // INT8_MIN is input_min: the table sweep covers [-128, 127].
+    (long) output_zero_point, output_scale,
+    (long) output_min, (long) output_max,
+    flags,
+    (xnn_lut_init_fn) &calculate_sigmoid, NULL,  // No extra init parameters are needed for sigmoid.
+    xnn_operator_type_sigmoid_nc_qs8, sigmoid_op_out);
+}
+
 enum xnn_status xnn_create_sigmoid_nc_qu8(
     size_t channels,
     size_t input_stride,
@@ -219,7 +257,7 @@
 
   return create_lut_elementwise_nc(
     channels, input_stride, output_stride,
-    (int32_t) (uint32_t) input_zero_point, input_scale,
+    (int32_t) (uint32_t) input_zero_point, input_scale, 0 /* input min */,
     (long) (unsigned long) output_zero_point, output_scale,
     (long) (unsigned long) output_min, (long) (unsigned long) output_max,
     flags,
@@ -303,6 +341,18 @@
     batch_size, input, output);
 }
 
+enum xnn_status xnn_setup_sigmoid_nc_qs8(  // Sets up a QS8 Sigmoid operator through the shared lookup-table setup routine.
+    xnn_operator_t sigmoid_op,
+    size_t batch_size,
+    const int8_t* input,
+    int8_t* output,
+    pthreadpool_t threadpool)  // Not passed on to setup_lut_elementwise_nc.
+{
+  return setup_lut_elementwise_nc(
+    sigmoid_op, xnn_operator_type_sigmoid_nc_qs8,  // Operator type handed to the shared setup routine.
+    batch_size, input, output);
+}
+
 enum xnn_status xnn_setup_sigmoid_nc_qu8(
     xnn_operator_t sigmoid_op,
     size_t batch_size,
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 13af8e1..d73b6fb 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -90,6 +90,7 @@
   xnn_operator_type_resize_bilinear_nchw_f32,
   xnn_operator_type_resize_bilinear_nhwc_f32,
   xnn_operator_type_sigmoid_nc_f32,
+  xnn_operator_type_sigmoid_nc_qs8,
   xnn_operator_type_sigmoid_nc_qu8,
   xnn_operator_type_softmax_nc_f32,
   xnn_operator_type_softmax_nc_qu8,
diff --git a/test/sigmoid-nc.cc b/test/sigmoid-nc.cc
index 3f79efe..587f012 100644
--- a/test/sigmoid-nc.cc
+++ b/test/sigmoid-nc.cc
@@ -11,6 +11,212 @@
 #include "sigmoid-operator-tester.h"
 
 
+TEST(SIGMOID_NC_QS8, unit_batch) {  // batch_size = 1, no quantization setters overridden.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, unit_batch_with_qmin) {  // batch_size = 1 with a raised lower output clamp.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .qmin(128)  // TestQS8 passes int8_t(qmin() - 0x80) as output_min.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, unit_batch_with_qmax) {  // batch_size = 1 with a lowered upper output clamp.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(1)
+      .channels(channels)
+      .qmax(128)  // TestQS8 passes int8_t(qmax() - 0x80) as output_max.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, unit_batch_with_input_scale) {  // batch_size = 1 across several input scales.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    for (float input_scale = 1.0e-2f; input_scale < 1.0e+2f; input_scale *= 10.0f) {  // x10 steps starting at 1.0e-2
+      SigmoidOperatorTester()
+        .batch_size(1)
+        .channels(channels)
+        .input_scale(input_scale)
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(SIGMOID_NC_QS8, unit_batch_with_input_zero_point) {  // batch_size = 1 across input zero points.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    for (int32_t input_zero_point = 0; input_zero_point <= 255; input_zero_point += 51) {  // 0, 51, ..., 255
+      SigmoidOperatorTester()
+        .batch_size(1)
+        .channels(channels)
+        .input_zero_point(uint8_t(input_zero_point))  // TestQS8 subtracts 0x80 before passing it to the operator.
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(SIGMOID_NC_QS8, small_batch) {  // batch_size = 3, strides left at the tester's defaults.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, small_batch_with_input_stride) {  // batch_size = 3 with an explicit input stride.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)  // Larger than every tested channel count.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, small_batch_with_output_stride) {  // batch_size = 3 with an explicit output stride.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .output_stride(117)  // Larger than every tested channel count.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, small_batch_with_qmin) {  // batch_size = 3 with a raised lower output clamp.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .qmin(128)  // TestQS8 passes int8_t(qmin() - 0x80) as output_min.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, small_batch_with_qmax) {  // batch_size = 3 with a lowered upper output clamp.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .qmax(128)  // TestQS8 passes int8_t(qmax() - 0x80) as output_max.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, small_batch_with_input_scale) {  // batch_size = 3 across several input scales.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    for (float input_scale = 1.0e-2f; input_scale < 1.0e+2f; input_scale *= 10.0f) {  // x10 steps starting at 1.0e-2
+      SigmoidOperatorTester()
+        .batch_size(3)
+        .channels(channels)
+        .input_scale(input_scale)
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(SIGMOID_NC_QS8, small_batch_with_input_zero_point) {  // batch_size = 3 across input zero points.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    for (int32_t input_zero_point = 0; input_zero_point <= 255; input_zero_point += 51) {  // 0, 51, ..., 255
+      SigmoidOperatorTester()
+        .batch_size(3)
+        .channels(channels)
+        .input_zero_point(uint8_t(input_zero_point))  // TestQS8 subtracts 0x80 before passing it to the operator.
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(SIGMOID_NC_QS8, strided_batch) {  // batch_size = 3 with explicit input and output strides.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)  // Both strides exceed every tested channel count.
+      .output_stride(117)
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, strided_batch_with_qmin) {  // Strided batch of 3 with a raised lower output clamp.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .output_stride(117)
+      .qmin(128)  // TestQS8 passes int8_t(qmin() - 0x80) as output_min.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, strided_batch_with_qmax) {  // Strided batch of 3 with a lowered upper output clamp.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    SigmoidOperatorTester()
+      .batch_size(3)
+      .channels(channels)
+      .input_stride(129)
+      .output_stride(117)
+      .qmax(128)  // TestQS8 passes int8_t(qmax() - 0x80) as output_max.
+      .iterations(3)
+      .TestQS8();
+  }
+}
+
+TEST(SIGMOID_NC_QS8, strided_batch_with_input_scale) {  // Strided batch of 3 across several input scales.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    for (float input_scale = 1.0e-2f; input_scale < 1.0e+2f; input_scale *= 10.0f) {  // x10 steps starting at 1.0e-2
+      SigmoidOperatorTester()
+        .batch_size(3)
+        .channels(channels)
+        .input_stride(129)
+        .output_stride(117)
+        .input_scale(input_scale)
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
+TEST(SIGMOID_NC_QS8, strided_batch_with_input_zero_point) {  // Strided batch of 3 across input zero points.
+  for (size_t channels = 1; channels < 100; channels += 15) {  // channels = 1, 16, ..., 91
+    for (int32_t input_zero_point = 0; input_zero_point <= 255; input_zero_point += 51) {  // 0, 51, ..., 255
+      SigmoidOperatorTester()
+        .batch_size(3)
+        .channels(channels)
+        .input_stride(129)
+        .output_stride(117)
+        .input_zero_point(uint8_t(input_zero_point))  // TestQS8 subtracts 0x80 before passing it to the operator.
+        .iterations(1)
+        .TestQS8();
+    }
+  }
+}
+
 TEST(SIGMOID_NC_QU8, unit_batch) {
   for (size_t channels = 1; channels < 100; channels += 15) {
     SigmoidOperatorTester()
diff --git a/test/sigmoid-operator-tester.h b/test/sigmoid-operator-tester.h
index 9105e07..56d727e 100644
--- a/test/sigmoid-operator-tester.h
+++ b/test/sigmoid-operator-tester.h
@@ -130,6 +130,69 @@
     return this->iterations_;
   }
 
+  void TestQS8() const {  // Compares the QS8 Sigmoid operator with a float reference on random int8 inputs.
+    std::random_device random_device;
+    auto rng = std::mt19937(random_device());
+    auto i8rng = std::bind(  // Draws uniformly from the full int8_t range.
+      std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
+      std::ref(rng));
+
+    std::vector<int8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
+    std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels());
+    std::vector<float> output_ref(batch_size() * channels());  // Densely packed: one float per (batch, channel).
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(input.begin(), input.end(), std::ref(i8rng));
+      std::fill(output.begin(), output.end(), 0xA5);  // Pre-fill; presumably a sentinel for unwritten outputs.
+
+      // Compute reference results in float; tester parameters are shifted by -0x80 into the int8 domain.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          const float x = input_scale() *  // Dequantized input value.
+            (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point() - 0x80));
+          const float sigmoid_x = 1.0f / (1.0f + std::exp(-x));
+          const float scaled_sigmoid_x = sigmoid_x / output_scale();
+          float y = scaled_sigmoid_x;
+          y = std::min<float>(y, int32_t(qmax() - 0x80) - int32_t(output_zero_point() - 0x80));  // Clamp relative to the zero point.
+          y = std::max<float>(y, int32_t(qmin() - 0x80) - int32_t(output_zero_point() - 0x80));
+          output_ref[i * channels() + c] = y + int32_t(output_zero_point() - 0x80);  // Unrounded expected quantized value.
+        }
+      }
+
+      // Create, setup, run, and destroy the QS8 Sigmoid operator.
+      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
+      xnn_operator_t sigmoid_op = nullptr;
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_create_sigmoid_nc_qs8(
+          channels(), input_stride(), output_stride(),
+          int8_t(input_zero_point() - 0x80), input_scale(),
+          int8_t(output_zero_point() - 0x80), output_scale(),
+          int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
+          0, &sigmoid_op));
+      ASSERT_NE(nullptr, sigmoid_op);
+
+      // Smart pointer to automatically delete sigmoid_op, including on early ASSERT_* returns.
+      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sigmoid_op(sigmoid_op, xnn_delete_operator);
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_setup_sigmoid_nc_qs8(
+          sigmoid_op,
+          batch_size(),
+          input.data(), output.data(),
+          nullptr /* thread pool */));
+
+      ASSERT_EQ(xnn_status_success,
+        xnn_run_operator(sigmoid_op, nullptr /* thread pool */));
+
+      // Verify results: each output must lie within 0.6 of the unrounded float reference.
+      for (size_t i = 0; i < batch_size(); i++) {
+        for (size_t c = 0; c < channels(); c++) {
+          ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
+        }
+      }
+    }
+  }
+
   void TestQU8() const {
     std::random_device random_device;
     auto rng = std::mt19937(random_device());