Support ngram entry migration.

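Before re-adding entries to the migrated dictionary, every ngram entry
related to a word is dumped from the language model dict content by
walking the trie map depth first
(LanguageModelDictContent::exportAllNgramEntriesRelatedToWord). The toy
program below only illustrates that walk; Slot, Level and collect() are
made-up stand-ins, not the real TrieMap API. It gathers
(prev word ids, target word id) pairs the same way
exportAllNgramEntriesRelatedToWordInner() does.

  #include <cstdio>
  #include <map>
  #include <utility>
  #include <vector>

  // One slot in a toy trie level: whether an ngram entry ends at this word
  // and, if the word also starts longer ngrams, the index of the next level.
  struct Slot {
      bool hasEntry;
      int nextLevelIndex;  // index into the level table, or -1 if none
  };
  using Level = std::map<int, Slot>;  // word id -> slot at one trie level

  // Depth-first export of (prev word ids, target word id) pairs.
  static void collect(const std::vector<Level> &levels, const int levelIndex,
          std::vector<int> *prevWordIds,
          std::vector<std::pair<std::vector<int>, int>> *out) {
      for (const auto &slotEntry : levels[levelIndex]) {
          const int wordId = slotEntry.first;
          const Slot &slot = slotEntry.second;
          if (slot.hasEntry) {
              // An ngram entry: prev words are the path so far, target is wordId.
              out->emplace_back(*prevWordIds, wordId);
          }
          if (slot.nextLevelIndex >= 0) {
              prevWordIds->push_back(wordId);
              collect(levels, slot.nextLevelIndex, prevWordIds, out);
              prevWordIds->pop_back();
          }
      }
  }

  int main() {
      // Entries related to word 1: (1 -> 2), (1 -> 3) and (1, 2 -> 3).
      std::vector<Level> levels(2);
      levels[0][2] = { true, 1 };   // level reached through word 1
      levels[0][3] = { true, -1 };
      levels[1][3] = { true, -1 };  // level reached through words 1, 2
      std::vector<int> prevWordIds = { 1 };
      std::vector<std::pair<std::vector<int>, int>> entries;
      collect(levels, 0, &prevWordIds, &entries);
      for (const auto &entry : entries) {
          for (const int id : entry.first) printf("%d ", id);
          printf("-> %d\n", entry.second);
      }
      return 0;
  }

Run as-is it prints "1 -> 2", "1 2 -> 3" and "1 -> 3"; the real walk
additionally records each entry's word attributes and probability entry
so the policy layer can re-add the ngrams after migration.
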
Bug: 14425059
Change-Id: I98cb9fa303af2d93a0a3512e8732231c564e3c5d
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
index e72226c..461d1d8 100644
--- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
+++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
@@ -629,8 +629,7 @@
         }
     } while (token != 0);
 
-    // Add bigrams.
-    // TODO: Support ngrams.
+    // Add ngrams.
     do {
         token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
         const WordProperty wordProperty = dictionary->getWordProperty(
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 66cb051..08e39ce 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -580,10 +580,12 @@
                     getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
                     bigramWord1CodePoints);
             const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
-            const int probability = bigramEntry.hasHistoricalInfo() ?
-                    ForgettingCurveUtils::decodeProbability(
-                            bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
-                    bigramEntry.getProbability();
+            const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
+                    ? ForgettingCurveUtils::decodeProbability(
+                            bigramEntry.getHistoricalInfo(), mHeaderPolicy)
+                    : bigramEntry.getProbability();
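+            // Express the stored bigram probability as a probability conditioned on the
+            // previous word's unigram probability in the returned ngram property.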
+            const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
+                    ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
             ngrams.emplace_back(
                     NgramContext(wordCodePoints.data(), wordCodePoints.size(),
                             ptNodeParams.representsBeginningOfSentence()),
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index a889965..b962904 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -140,6 +140,44 @@
     return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
 }
 
+std::vector<LanguageModelDictContent::DumpedFullEntryInfo>
+        LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
+                const HeaderPolicy *const headerPolicy, const int wordId) const {
+    const TrieMap::Result result = mTrieMap.getRoot(wordId);
+    if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+        // The word doesn't have any related ngram entries.
+        return std::vector<DumpedFullEntryInfo>();
+    }
+    std::vector<int> prevWordIds = { wordId };
+    std::vector<DumpedFullEntryInfo> entries;
+    exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
+            &prevWordIds, &entries);
+    return entries;
+}
+
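+// Recursively walks one trie map level: each key at this level is the target word of
+// an ngram entry whose previous words are *prevWordIds, and a key that has a next
+// level map extends the context for longer ngrams reachable below it.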
+void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
+        const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
+        std::vector<int> *const prevWordIds,
+        std::vector<DumpedFullEntryInfo> *const outDumpedFullEntryInfo) const {
+    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+        const int wordId = entry.key();
+        const ProbabilityEntry probabilityEntry =
+                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+        if (probabilityEntry.isValid()) {
+            const WordAttributes wordAttributes = getWordAttributes(
+                    WordIdArrayView(*prevWordIds), wordId, headerPolicy);
+            outDumpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
+                    wordAttributes, probabilityEntry);
+        }
+        if (entry.hasNextLevelMap()) {
+            prevWordIds->push_back(wordId);
+            exportAllNgramEntriesRelatedToWordInner(headerPolicy,
+                    entry.getNextLevelBitmapEntryIndex(), prevWordIds, outDumpedFullEntryInfo);
+            prevWordIds->pop_back();
+        }
+    }
+}
+
 bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
         const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
         MutableEntryCounters *const outEntryCounters) {
@@ -231,24 +269,25 @@
 }
 
 int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
-    if (prevWordIds.empty()) {
-        return mTrieMap.getRootBitmapEntryIndex();
-    }
-    const int lastBitmapEntryIndex =
-            getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
-    if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
-        return TrieMap::INVALID_INDEX;
-    }
-    const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
-    const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
-    if (!result.mIsValid) {
-        if (!mTrieMap.put(oldestPrevWordId,
-                ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
-            return TrieMap::INVALID_INDEX;
+    int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
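+    // Follow prevWordIds down the trie one level per word, creating missing intermediate
+    // entries (with an empty probability entry) and next level maps on the way.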
+    for (const int wordId : prevWordIds) {
+        const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
+        if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
+            lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
+            continue;
         }
+        if (!result.mIsValid) {
+            if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
+                    lastBitmapEntryIndex)) {
+                AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
+                        lastBitmapEntryIndex);
+                return TrieMap::INVALID_INDEX;
+            }
+        }
+        lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
+                lastBitmapEntryIndex);
     }
-    return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
-            lastBitmapEntryIndex);
+    return lastBitmapEntryIndex;
 }
 
 int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 41a429a..1cccf92 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -110,6 +110,27 @@
         const bool mHasHistoricalInfo;
     };
 
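+    // Read-only snapshot of one dumped ngram entry: the previous word ids, the target
+    // word id, and the word attributes and probability entry associated with it.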
+    class DumpedFullEntryInfo {
+     public:
+        DumpedFullEntryInfo(const std::vector<int> &prevWordIds, const int targetWordId,
+                const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
+                : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
+                  mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
+
+        const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
+        int getTargetWordId() const { return mTargetWordId; }
+        const WordAttributes &getWordAttributes() const { return mWordAttributes; }
+        const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
+
+     private:
+        DISALLOW_ASSIGNMENT_OPERATOR(DumpedFullEntryInfo);
+
+        const std::vector<int> mPrevWordIds;
+        const int mTargetWordId;
+        const WordAttributes mWordAttributes;
+        const ProbabilityEntry mProbabilityEntry;
+    };
+
     LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
             const bool hasHistoricalInfo)
             : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
@@ -151,6 +172,9 @@
 
     EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
 
+    std::vector<DumpedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
+            const HeaderPolicy *const headerPolicy, const int wordId) const;
+
     bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
             MutableEntryCounters *const outEntryCounters) {
         return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
@@ -212,6 +236,9 @@
     const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
             const bool isValid, const HistoricalInfo historicalInfo,
             const HeaderPolicy *const headerPolicy) const;
+    void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
+            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+            std::vector<DumpedFullEntryInfo> *const outDumpedFullEntryInfo) const;
 };
 } // namespace latinime
 #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 1669375..193326d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -491,30 +491,37 @@
     const int ptNodePos =
             mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
     const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
-    const ProbabilityEntry probabilityEntry =
-            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
-                    ptNodeParams.getTerminalId());
-    const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
-    // Fetch bigram information.
-    // TODO: Support n-gram.
+    const LanguageModelDictContent *const languageModelDictContent =
+            mBuffers->getLanguageModelDictContent();
+    // Fetch ngram information.
     std::vector<NgramProperty> ngrams;
-    const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
-    int bigramWord1CodePoints[MAX_WORD_LENGTH];
-    for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
-            prevWordIds)) {
-        const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
-                MAX_WORD_LENGTH, bigramWord1CodePoints);
+    int ngramTargetCodePoints[MAX_WORD_LENGTH];
+    int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
+    int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+    bool ngramPrevWordIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+    for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
+            mHeaderPolicy, wordId)) {
+        const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
+                MAX_WORD_LENGTH, ngramTargetCodePoints);
+        const WordIdArrayView prevWordIds = entry.getPrevWordIds();
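+        // Convert each previous word id to code points and strip the beginning-of-sentence
+        // marker; the beginning-of-sentence flag is passed to the NgramContext separately.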
+        for (size_t i = 0; i < prevWordIds.size(); ++i) {
+            ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
+                    MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
+            ngramPrevWordIsBeginningOfSentence[i] = languageModelDictContent->getProbabilityEntry(
+                    prevWordIds[i]).representsBeginningOfSentence();
+            if (ngramPrevWordIsBeginningOfSentence[i]) {
+                ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
+                        ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
+            }
+        }
+        const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
+                ngramPrevWordIsBeginningOfSentence, prevWordIds.size());
         const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
         const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
-        const int probability = ngramProbabilityEntry.hasHistoricalInfo() ?
-                ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
-                ngramProbabilityEntry.getProbability();
-        ngrams.emplace_back(
-                NgramContext(
-                        wordCodePoints.data(), wordCodePoints.size(),
-                        probabilityEntry.representsBeginningOfSentence()),
-                CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
-                probability, *historicalInfo);
+        // TODO: Output flags in WordAttributes.
+        ngrams.emplace_back(ngramContext,
+                CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
+                entry.getWordAttributes().getProbability(), *historicalInfo);
     }
     // Fetch shortcut information.
     std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@@ -534,6 +541,9 @@
                     shortcutProbability);
         }
     }
+    const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(
+            ptNodeParams.getTerminalId());
+    const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
     const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
             probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
             probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),
diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h
index 5e9cdd9..7871c26 100644
--- a/native/jni/src/utils/char_utils.h
+++ b/native/jni/src/utils/char_utils.h
@@ -101,6 +101,17 @@
         return codePointCount + 1;
     }
 
+    // Removes the beginning-of-sentence marker at the head of codePoints, if any.
+    // Returns the updated code point count.
+    static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints,
+            const int codePointCount) {
+        if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) {
+            return codePointCount;
+        }
+        const int newCodePointCount = codePointCount - 1;
+        memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount);
+        return newCodePointCount;
+    }
+
  private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);
 
diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
index 83e523c..9eb03e6 100644
--- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
+++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
@@ -653,6 +653,13 @@
         assertFalse(binaryDictionary.isValidWord("bbb"));
         assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
 
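+        // Record trigram entries before the migration when the new format supports ngrams.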
+        if (supportsNgram(toFormatVersion)) {
+            onInputWordWithPrevWords(binaryDictionary, "xyz", true /* isValidWord */,
+                    "abc", "aaa");
+            assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
+            onInputWordWithPrevWords(binaryDictionary, "def", false /* isValidWord */,
+                    "abc", "aaa");
+            assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
+        }
+
         assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion());
         assertTrue(binaryDictionary.migrateTo(toFormatVersion));
         assertTrue(binaryDictionary.isValidDictionary());
@@ -666,6 +673,14 @@
         assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
         onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa");
         assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb"));
+
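+        // Check that the trigram entries recorded before the migration were carried over.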
+        if (supportsNgram(toFormatVersion)) {
+            assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
+            assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
+            onInputWordWithPrevWords(binaryDictionary, "def", false /* isValidWord */,
+                    "abc", "aaa");
+            assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
+        }
+
         binaryDictionary.close();
     }