RNN VAD: fix pitch gain type and change pitch period type

The pitch gain type in ComputePitchGainThreshold() is wrong
(size_t instead of float).
The pitch period is an unsigned integer type, but it is safer to
switch to a signed type and add checks on the sign.

Bug: webrtc:9076
Change-Id: If69d182071edab9750a320f0fbfac24aa8052ee0
Reviewed-on: https://webrtc-review.googlesource.com/c/117302
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#26259}
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_info.h b/modules/audio_processing/agc2/rnn_vad/pitch_info.h
index f0998d1..c9fdd18 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_info.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_info.h
@@ -18,8 +18,8 @@
 // strength of the pitch (the higher, the stronger).
 struct PitchInfo {
   PitchInfo() : period(0), gain(0.f) {}
-  PitchInfo(size_t p, float g) : period(p), gain(g) {}
-  size_t period;
+  PitchInfo(int p, float g) : period(p), gain(g) {}
+  int period;
   float gain;
 };
 
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
index 32ee8c0..7c17dfb 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
@@ -128,12 +128,12 @@
 //     sn = mex({n * i for i in S} | {1})
 //     S = S | {Fraction(1, n), Fraction(sn, n)}
 //     print(sn, end=', ')
-constexpr std::array<size_t, 14> kSubHarmonicMultipliers = {
+constexpr std::array<int, 14> kSubHarmonicMultipliers = {
     {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
 
 // Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
 // a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
-constexpr std::array<size_t, 14> kInitialPitchPeriodThresholds = {
+constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
     {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
 
 }  // namespace
@@ -147,31 +147,34 @@
   }
 }
 
-float ComputePitchGainThreshold(size_t candidate_pitch_period,
-                                size_t pitch_period_ratio,
-                                size_t initial_pitch_period,
+float ComputePitchGainThreshold(int candidate_pitch_period,
+                                int pitch_period_ratio,
+                                int initial_pitch_period,
                                 float initial_pitch_gain,
-                                size_t prev_pitch_period,
-                                size_t prev_pitch_gain) {
+                                int prev_pitch_period,
+                                float prev_pitch_gain) {
   // Map arguments to more compact aliases.
-  const size_t& t1 = candidate_pitch_period;
-  const size_t& k = pitch_period_ratio;
-  const size_t& t0 = initial_pitch_period;
+  const int& t1 = candidate_pitch_period;
+  const int& k = pitch_period_ratio;
+  const int& t0 = initial_pitch_period;
   const float& g0 = initial_pitch_gain;
-  const size_t& t_prev = prev_pitch_period;
-  const size_t& g_prev = prev_pitch_gain;
+  const int& t_prev = prev_pitch_period;
+  const float& g_prev = prev_pitch_gain;
 
   // Validate input.
+  RTC_DCHECK_GE(t1, 0);
   RTC_DCHECK_GE(k, 2);
+  RTC_DCHECK_GE(t0, 0);
+  RTC_DCHECK_GE(t_prev, 0);
 
   // Compute a term that lowers the threshold when |t1| is close to the last
   // estimated period |t_prev| - i.e., pitch tracking.
   float lower_threshold_term = 0;
-  if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) <= 1) {
+  if (abs(t1 - t_prev) <= 1) {
     // The candidate pitch period is within 1 sample from the previous one.
     // Make the candidate at |t1| very easy to be accepted.
     lower_threshold_term = g_prev;
-  } else if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) == 2 &&
+  } else if (abs(t1 - t_prev) == 2 &&
              t0 > kInitialPitchPeriodThresholds[k - 2]) {
     // The candidate pitch period is 2 samples far from the previous one and the
     // period |t0| (from which |t1| has been derived) is greater than a
@@ -182,9 +185,11 @@
   // reduce the chance of false positives caused by a bias towards high
   // frequencies (originating from short-term correlations).
   float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
-  if (t1 < 3 * kMinPitch24kHz) {  // High frequency.
+  if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
+    // High frequency.
     threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
-  } else if (t1 < 2 * kMinPitch24kHz) {  // Even higher frequency.
+  } else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
+    // Even higher frequency.
     threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
   }
   return threshold;
@@ -350,16 +355,16 @@
 
 PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
     rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
-    size_t initial_pitch_period_48kHz,
+    int initial_pitch_period_48kHz,
     PitchInfo prev_pitch_48kHz) {
   RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
   RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
   // Stores information for a refined pitch candidate.
   struct RefinedPitchCandidate {
     RefinedPitchCandidate() {}
-    RefinedPitchCandidate(size_t period_24kHz, float gain, float xy, float yy)
+    RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
         : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
-    size_t period_24kHz;
+    int period_24kHz;
     // Pitch strength information.
     float gain;
     // Additional pitch strength information used for the final estimation of
@@ -380,8 +385,8 @@
   };
   // Initial pitch candidate gain.
   RefinedPitchCandidate best_pitch;
-  best_pitch.period_24kHz =
-      std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
+  best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
+                                     static_cast<int>(kMaxPitch24kHz - 1));
   best_pitch.xy = ComputeAutoCorrelationCoeff(
       pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
   best_pitch.yy = yy_values[best_pitch.period_24kHz];
@@ -392,24 +397,27 @@
   const float initial_pitch_gain = best_pitch.gain;
 
   // Given the initial pitch estimation, check lower periods (i.e., harmonics).
-  const auto alternative_period = [](size_t period, size_t k,
-                                     size_t n) -> size_t {
-    RTC_DCHECK_LT(0, k);
+  const auto alternative_period = [](int period, int k, int n) -> int {
+    RTC_DCHECK_GT(k, 0);
     return (2 * n * period + k) / (2 * k);  // Same as round(n*period/k).
   };
-  for (size_t k = 2; k < kSubHarmonicMultipliers.size() + 2; ++k) {
-    size_t candidate_pitch_period =
-        alternative_period(initial_pitch_period, k, 1);
-    if (candidate_pitch_period < kMinPitch24kHz)
+  for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
+       ++k) {
+    int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
+    if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
       break;
+    }
     // When looking at |candidate_pitch_period|, we also look at one of its
     // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
     // |k| == 2 is a special case since |candidate_pitch_secondary_period| might
     // be greater than the maximum pitch period.
-    size_t candidate_pitch_secondary_period = alternative_period(
+    int candidate_pitch_secondary_period = alternative_period(
         initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
-    if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz)
+    RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
+    if (k == 2 &&
+        candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
       candidate_pitch_secondary_period = initial_pitch_period;
+    }
     RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
         << "The lower pitch period and the additional sub-harmonic must not "
         << "coincide.";
@@ -442,7 +450,7 @@
                                ? 1.f
                                : best_pitch.xy / (best_pitch.yy + 1.f);
   final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
-  size_t final_pitch_period_48kHz = std::max(
+  int final_pitch_period_48kHz = std::max(
       kMinPitch48kHz,
       PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
 
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
index bb747bb..aabf713 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
@@ -41,12 +41,12 @@
 // Computes a gain threshold for a candidate pitch period given the initial and
 // the previous pitch period and gain estimates and the pitch period ratio used
 // to derive the candidate pitch period from the initial period.
-float ComputePitchGainThreshold(size_t candidate_pitch_period,
-                                size_t pitch_period_ratio,
-                                size_t initial_pitch_period,
+float ComputePitchGainThreshold(int candidate_pitch_period,
+                                int pitch_period_ratio,
+                                int initial_pitch_period,
                                 float initial_pitch_gain,
-                                size_t prev_pitch_period,
-                                size_t prev_pitch_gain);
+                                int prev_pitch_period,
+                                float prev_pitch_gain);
 
 // Computes the sum of squared samples for every sliding frame in the pitch
 // buffer. |yy_values| indexes are lags.
@@ -99,7 +99,7 @@
 // refined pitch estimation data at 48 kHz.
 PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
     rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
-    size_t initial_pitch_period_48kHz,
+    int initial_pitch_period_48kHz,
     PitchInfo prev_pitch_48kHz);
 
 }  // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index 82b4810..033ea3e 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -24,8 +24,9 @@
 namespace test {
 namespace {
 
-constexpr std::array<size_t, 2> kTestPitchPeriods = {
-    3 * kMinPitch48kHz / 2, (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
+constexpr std::array<int, 2> kTestPitchPeriods = {
+    3 * kMinPitch48kHz / 2,
+    (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
 };
 constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
 
@@ -197,14 +198,14 @@
 class CheckLowerPitchPeriodsAndComputePitchGainTest
     : public testing::Test,
       public ::testing::WithParamInterface<
-          std::tuple<size_t, size_t, float, size_t, float>> {};
+          std::tuple<int, int, float, int, float>> {};
 
 TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
   const auto params = GetParam();
-  const size_t initial_pitch_period = std::get<0>(params);
-  const size_t prev_pitch_period = std::get<1>(params);
+  const int initial_pitch_period = std::get<0>(params);
+  const int prev_pitch_period = std::get<1>(params);
   const float prev_pitch_gain = std::get<2>(params);
-  const size_t expected_pitch_period = std::get<3>(params);
+  const int expected_pitch_period = std::get<3>(params);
   const float expected_pitch_gain = std::get<4>(params);
   TestData test_data;
   {
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
index 4c56238..eac332e 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
@@ -39,7 +39,7 @@
       lp_residual_reader.first->ReadValue(&expected_pitch_period);
       lp_residual_reader.first->ReadValue(&expected_pitch_gain);
       PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual);
-      EXPECT_EQ(static_cast<size_t>(expected_pitch_period), pitch_info.period);
+      EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
       EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
     }
   }