Merge "Update useless n-gram entry detection logic during GC."
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index a7296a3..c4297f5 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -270,16 +270,26 @@
 }
 
 bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
-        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
+        const int prevWordCount, const HeaderPolicy *const headerPolicy,
+        int *const outEntryCounts) {
     for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
-        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
-            AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
-                    level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
+        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
+                    prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
             return false;
         }
         const ProbabilityEntry probabilityEntry =
                 ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
-        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
+        if (prevWordCount > 0 && probabilityEntry.isValid()
+                && !mTrieMap.getRoot(entry.key()).mIsValid) {
+            // The entry is related to a word that has been removed. Remove the entry.
+            if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+                return false;
+            }
+            continue;
+        }
+        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()
+                && probabilityEntry.isValid()) {
             const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                     probabilityEntry.getHistoricalInfo(), headerPolicy);
             if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
@@ -298,13 +308,13 @@
             }
         }
         if (!probabilityEntry.representsBeginningOfSentence()) {
-            outEntryCounts[level] += 1;
+            outEntryCounts[prevWordCount] += 1;
         }
         if (!entry.hasNextLevelMap()) {
             continue;
         }
-        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
-                headerPolicy, outEntryCounts)) {
+        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
+                prevWordCount + 1, headerPolicy, outEntryCounts)) {
             return false;
         }
     }
@@ -332,7 +342,7 @@
     for (int i = 0; i < entryCountToRemove; ++i) {
         const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
         if (!removeNgramProbabilityEntry(
-                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
+                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) {
             return false;
         }
     }
@@ -342,9 +352,9 @@
 bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
         const int targetLevel, const int bitmapEntryIndex,  std::vector<int> *const prevWordIds,
         std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
-    const int currentLevel = prevWordIds->size();
+    const int prevWordCount = prevWordIds->size();
     for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
-        if (currentLevel < targetLevel) {
+        if (prevWordCount < targetLevel) {
             if (!entry.hasNextLevelMap()) {
                 continue;
             }
@@ -379,10 +389,10 @@
     if (left.mKey != right.mKey) {
         return left.mKey < right.mKey;
     }
-    if (left.mEntryLevel != right.mEntryLevel) {
-        return left.mEntryLevel > right.mEntryLevel;
+    if (left.mPrevWordCount != right.mPrevWordCount) {
+        return left.mPrevWordCount > right.mPrevWordCount;
     }
-    for (int i = 0; i < left.mEntryLevel; ++i) {
+    for (int i = 0; i < left.mPrevWordCount; ++i) {
         if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
             return left.mPrevWordIds[i] < right.mPrevWordIds[i];
         }
@@ -392,9 +402,10 @@
 }
 
 LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
-        const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
-        : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
-    memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
+        const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds)
+        : mProbability(probability), mTimestamp(timestamp), mKey(key),
+          mPrevWordCount(prevWordCount) { // NOTE(review): prevWordCount is not range-checked here
+    memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
 }
 
 } // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 834cf93..51ef090 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -160,7 +160,7 @@
             outEntryCounts[i] = 0;
         }
         return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
-                0 /* level */, headerPolicy, outEntryCounts);
+                0 /* prevWordCount */, headerPolicy, outEntryCounts);
     }
 
     // entryCounts should be created by updateAllProbabilityEntries.
@@ -185,12 +185,12 @@
         };
 
         EntryInfoToTurncate(const int probability, const int timestamp, const int key,
-                const int entryLevel, const int *const prevWordIds);
+                const int prevWordCount, const int *const prevWordIds);
 
         int mProbability;
         int mTimestamp;
         int mKey;
-        int mEntryLevel;
+        int mPrevWordCount;
         int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
 
      private:
@@ -208,7 +208,7 @@
             int *const outNgramCount);
     int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
-    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int level,
+    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
             const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
     bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
             const int maxEntryCount, const int targetLevel, int *const outEntryCount);
diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
index 0e58b72..fa70f99 100644
--- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
+++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
@@ -75,6 +75,10 @@
         return formatVersion > FormatSpec.VERSION401;
     }
 
+    private static boolean supportsNgram(final int formatVersion) {
+        return formatVersion >= FormatSpec.VERSION4_DEV; // gates the trigram test sections
+    }
+
     private void onInputWord(final BinaryDictionary binaryDictionary, final String word,
             final boolean isValidWord) {
         binaryDictionary.updateEntriesForWordWithNgramContext(NgramContext.EMPTY_PREV_WORDS_INFO,
@@ -88,6 +92,14 @@
                 mCurrentTime /* timestamp */);
     }
 
+    private void onInputWordWithPrevWords(final BinaryDictionary binaryDictionary,
+            final String word, final boolean isValidWord, final String prevWord,
+            final String prevPrevWord) { // Feeds one occurrence of word with a two-word context.
+        binaryDictionary.updateEntriesForWordWithNgramContext(
+                new NgramContext(new WordInfo(prevWord), new WordInfo(prevPrevWord)), word,
+                isValidWord, 1 /* count */, mCurrentTime /* timestamp */);
+    }
+
     private void onInputWordWithBeginningOfSentenceContext(
             final BinaryDictionary binaryDictionary, final String word, final boolean isValidWord) {
         binaryDictionary.updateEntriesForWordWithNgramContext(NgramContext.BEGINNING_OF_SENTENCE,
@@ -99,6 +111,12 @@
         return binaryDictionary.isValidNgram(new NgramContext(new WordInfo(word0)), word1);
     }
 
+    private static boolean isValidTrigram(final BinaryDictionary binaryDictionary,
+            final String word0, final String word1, final String word2) {
+        return binaryDictionary.isValidNgram( // context order: word1, then word0
+                new NgramContext(new WordInfo(word1), new WordInfo(word0)), word2);
+    }
+
     private void forcePassingShortTime(final BinaryDictionary binaryDictionary) {
         // 30 days.
         final int timeToElapse = (int)TimeUnit.SECONDS.convert(30, TimeUnit.DAYS);
@@ -256,7 +274,23 @@
         onInputWordWithPrevWord(binaryDictionary, "y", true /* isValidWord */, "x");
         assertFalse(isValidBigram(binaryDictionary, "x", "y"));
 
+        // Trigram checks need n-gram support; guard them so close() below runs on every path.
+        if (supportsNgram(formatVersion)) {
+            onInputWordWithPrevWords(binaryDictionary, "c", false /* isValidWord */, "b", "a");
+            assertFalse(isValidTrigram(binaryDictionary, "a", "b", "c"));
+            assertFalse(isValidBigram(binaryDictionary, "b", "c"));
+            onInputWordWithPrevWords(binaryDictionary, "c", false /* isValidWord */, "b", "a");
+            assertTrue(isValidTrigram(binaryDictionary, "a", "b", "c"));
+            assertTrue(isValidBigram(binaryDictionary, "b", "c"));
+
+            onInputWordWithPrevWords(binaryDictionary, "d", true /* isValidWord */, "b", "a");
+            assertTrue(isValidTrigram(binaryDictionary, "a", "b", "d"));
+            assertTrue(isValidBigram(binaryDictionary, "b", "d"));
+
+            onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "b", "a");
+            assertTrue(isValidTrigram(binaryDictionary, "a", "b", "cd"));
+        }
         binaryDictionary.close();
     }
 
     public void testDecayingProbability() {
@@ -301,6 +335,31 @@
         forcePassingLongTime(binaryDictionary);
         assertFalse(isValidBigram(binaryDictionary, "a", "b"));
 
+        // Trigram decay checks only apply to formats with n-gram support. Guarding them with a
+        // block (instead of returning early) keeps the close() call below on every path.
+        if (supportsNgram(formatVersion)) {
+            onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+            onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+            onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+            assertTrue(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+            forcePassingShortTime(binaryDictionary);
+            assertFalse(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+
+            onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+            onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+            onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+            onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+            onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+            onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+            onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+            onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+            onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+            forcePassingShortTime(binaryDictionary);
+            assertTrue(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+            forcePassingLongTime(binaryDictionary);
+            assertFalse(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+        }
+
         binaryDictionary.close();
     }