| /* |
| * Copyright (C) 2014, The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" |
| |
| #include <algorithm> |
| #include <cstring> |
| |
| #include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" |
| #include "suggest/policyimpl/dictionary/utils/probability_utils.h" |
| |
| namespace latinime { |
| |
// Indices of the sub-buffers that make up this dict content: the trie map body is
// stored first, followed by the global counters (matching the write order in save()).
const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;
| |
| bool LanguageModelDictContent::save(FILE *const file) const { |
| return mTrieMap.save(file) && mGlobalCounters.save(file); |
| } |
| |
| bool LanguageModelDictContent::runGC( |
| const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, |
| const LanguageModelDictContent *const originalContent) { |
| return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(), |
| 0 /* nextLevelBitmapEntryIndex */); |
| } |
| |
// Computes the attributes (probability plus the unigram entry's flags) for wordId in the
// context prevWordIds. First finds the deepest trie-map level reachable with the context,
// then tries n-gram entries from the longest context down to the unigram and uses the
// first usable one. When mustMatchAllPrevWords is true, contexts shorter than the full
// prevWordIds are not consulted.
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
        const int wordId, const bool mustMatchAllPrevWords,
        const HeaderPolicy *const headerPolicy) const {
    // bitmapEntryIndices[i] is the trie-map level reached after consuming the first i
    // context words.
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxPrevWordCount = 0;
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        const int nextBitmapEntryIndex =
                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
            break;
        }
        maxPrevWordCount = i + 1;
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
    if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
        // The word should be treated as an invalid word.
        return WordAttributes();
    }
    // Back off from the longest matched context; i == 0 is the unigram level.
    for (int i = maxPrevWordCount; i >= 0; --i) {
        if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
            break;
        }
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
        int probability = NOT_A_PROBABILITY;
        if (mHasHistoricalInfo) {
            const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
            int contextCount = 0;
            if (i == 0) {
                // unigram: the context count is the total count over all words.
                contextCount = mGlobalCounters.getTotalCount();
            } else {
                // For an n-gram, take the context count from the entry of the context
                // itself (prevWordIds[0] under the remaining context words).
                const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
                        prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
                if (!prevWordProbabilityEntry.isValid()) {
                    continue;
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()
                        && historicalInfo->getCount() == 1) {
                    // BoS ngram requires multiple contextCount.
                    continue;
                }
                contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
            }
            // Estimate the raw probability from the counts, encode it, apply decay based
            // on the entry's historical info, then the back-off adjustment for order n.
            const float rawProbability =
                    DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
                            historicalInfo->getCount(), contextCount, i + 1);
            const int encodedRawProbability =
                    ProbabilityUtils::encodeRawProbability(rawProbability);
            const int decayedProbability =
                    DynamicLanguageModelProbabilityUtils::getDecayedProbability(
                            encodedRawProbability, *historicalInfo);
            probability = DynamicLanguageModelProbabilityUtils::backoff(
                    decayedProbability, i + 1 /* n */);
        } else {
            probability = probabilityEntry.getProbability();
        }
        // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
        // probabilityEntry.
        return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
                unigramProbabilityEntry.isNotAWord(),
                unigramProbabilityEntry.isPossiblyOffensive());
    }
    // Cannot find the word.
    return WordAttributes();
}
| |
| ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( |
| const WordIdArrayView prevWordIds, const int wordId) const { |
| const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); |
| if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { |
| return ProbabilityEntry(); |
| } |
| const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex); |
| if (!result.mIsValid) { |
| // Not found. |
| return ProbabilityEntry(); |
| } |
| return ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); |
| } |
| |
| bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds, |
| const int wordId, const ProbabilityEntry *const probabilityEntry) { |
| if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) { |
| return false; |
| } |
| const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds); |
| if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { |
| return false; |
| } |
| return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); |
| } |
| |
| bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, |
| const int wordId) { |
| const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); |
| if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { |
| // Cannot find bitmap entry for the probability entry. The entry doesn't exist. |
| return false; |
| } |
| return mTrieMap.remove(wordId, bitmapEntryIndex); |
| } |
| |
| LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries( |
| const WordIdArrayView prevWordIds) const { |
| const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); |
| return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); |
| } |
| |
| std::vector<LanguageModelDictContent::DumppedFullEntryInfo> |
| LanguageModelDictContent::exportAllNgramEntriesRelatedToWord( |
| const HeaderPolicy *const headerPolicy, const int wordId) const { |
| const TrieMap::Result result = mTrieMap.getRoot(wordId); |
| if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) { |
| // The word doesn't have any related ngram entries. |
| return std::vector<DumppedFullEntryInfo>(); |
| } |
| std::vector<int> prevWordIds = { wordId }; |
| std::vector<DumppedFullEntryInfo> entries; |
| exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex, |
| &prevWordIds, &entries); |
| return entries; |
| } |
| |
| void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner( |
| const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex, |
| std::vector<int> *const prevWordIds, |
| std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const { |
| for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { |
| const int wordId = entry.key(); |
| const ProbabilityEntry probabilityEntry = |
| ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); |
| if (probabilityEntry.isValid()) { |
| const WordAttributes wordAttributes = getWordAttributes( |
| WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */, |
| headerPolicy); |
| outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, |
| wordAttributes, probabilityEntry); |
| } |
| if (entry.hasNextLevelMap()) { |
| prevWordIds->push_back(wordId); |
| exportAllNgramEntriesRelatedToWordInner(headerPolicy, |
| entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo); |
| prevWordIds->pop_back(); |
| } |
| } |
| } |
| |
| bool LanguageModelDictContent::truncateEntries(const EntryCounts ¤tEntryCounts, |
| const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy, |
| MutableEntryCounters *const outEntryCounters) { |
| for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) { |
| const int totalWordCount = prevWordCount + 1; |
| if (currentEntryCounts.getNgramCount(totalWordCount) |
| <= maxEntryCounts.getNgramCount(totalWordCount)) { |
| outEntryCounters->setNgramCount(totalWordCount, |
| currentEntryCounts.getNgramCount(totalWordCount)); |
| continue; |
| } |
| int entryCount = 0; |
| if (!turncateEntriesInSpecifiedLevel(headerPolicy, |
| maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) { |
| return false; |
| } |
| outEntryCounters->setNgramCount(totalWordCount, entryCount); |
| } |
| return true; |
| } |
| |
// Records one occurrence of wordId: merges historicalInfo into the unigram entry and
// into every n-gram entry formed with each prefix of prevWordIds, updating the global
// counters along the way. entryCountersToUpdate is bumped for each n-gram entry that is
// newly created. Only valid for dictionaries that keep historical info.
bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
        const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
        const HeaderPolicy *const headerPolicy, MutableEntryCounters *const entryCountersToUpdate) {
    if (!mHasHistoricalInfo) {
        AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
        return false;
    }
    // Merge the new occurrence into the unigram entry first.
    const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId);
    const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom(
            originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy);
    if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
        return false;
    }
    mGlobalCounters.incrementTotalCount();
    mGlobalCounters.updateMaxValueOfCounters(
            updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount());
    // Then update the (i + 2)-gram entry for each context prefix of length i + 1.
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        if (prevWordIds[i] == NOT_A_WORD_ID) {
            // Shorter context than the array allows; stop at the first missing word.
            break;
        }
        // TODO: Optimize this code.
        const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1);
        const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry(
                limitedPrevWordIds, wordId);
        const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom(
                originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy);
        if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
            return false;
        }
        mGlobalCounters.updateMaxValueOfCounters(
                updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
        if (!originalNgramProbabilityEntry.isValid()) {
            // A brand-new n-gram entry was created; account for it.
            entryCountersToUpdate->incrementNgramCount(i + 2);
        }
    }
    return true;
}
| |
| const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( |
| const ProbabilityEntry &originalProbabilityEntry, const bool isValid, |
| const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const { |
| const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(), |
| 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount() |
| + historicalInfo.getCount()); |
| if (originalProbabilityEntry.isValid()) { |
| return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); |
| } else { |
| return ProbabilityEntry(0 /* flags */, &updatedHistoricalInfo); |
| } |
| } |
| |
| bool LanguageModelDictContent::runGCInner( |
| const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, |
| const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) { |
| for (auto &entry : trieMapRange) { |
| const auto it = terminalIdMap->find(entry.key()); |
| if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) { |
| // The word has been removed. |
| continue; |
| } |
| if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) { |
| return false; |
| } |
| if (entry.hasNextLevelMap()) { |
| if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(), |
| mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| |
| int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) { |
| int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); |
| for (const int wordId : prevWordIds) { |
| const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex); |
| if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) { |
| lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex; |
| continue; |
| } |
| if (!result.mIsValid) { |
| if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo), |
| lastBitmapEntryIndex)) { |
| AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId, |
| lastBitmapEntryIndex); |
| return TrieMap::INVALID_INDEX; |
| } |
| } |
| lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId, |
| lastBitmapEntryIndex); |
| } |
| return lastBitmapEntryIndex; |
| } |
| |
| int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { |
| int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); |
| for (const int wordId : prevWordIds) { |
| const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex); |
| if (!result.mIsValid) { |
| return TrieMap::INVALID_INDEX; |
| } |
| bitmapEntryIndex = result.mNextLevelBitmapEntryIndex; |
| } |
| return bitmapEntryIndex; |
| } |
| |
// Walks every entry at bitmapEntryIndex (context length prevWordCount) and, recursively,
// all deeper levels. Removes entries whose word was deleted or whose historical info
// says they should be dropped, optionally halves the remaining counts, and tallies the
// surviving entries into outEntryCounters.
// NOTE(review): this removes entries from the level currently being iterated; assumes
// TrieMap iteration tolerates remove() of the current key — verify in TrieMap.
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int prevWordCount, const HeaderPolicy *const headerPolicy,
        const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
                    prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
            return false;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        if (prevWordCount > 0 && probabilityEntry.isValid()
                && !mTrieMap.getRoot(entry.key()).mIsValid) {
            // The entry is related to a word that has been removed. Remove the entry.
            if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                return false;
            }
            continue;
        }
        if (mHasHistoricalInfo && probabilityEntry.isValid()) {
            const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
            if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
                    *originalHistoricalInfo)) {
                // Remove the entry.
                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                    return false;
                }
                continue;
            }
            if (needsToHalveCounters) {
                // Halve the count; drop the entry once it reaches zero.
                const int updatedCount = originalHistoricalInfo->getCount() / 2;
                if (updatedCount == 0) {
                    // Remove the entry.
                    if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                        return false;
                    }
                    continue;
                }
                const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
                        originalHistoricalInfo->getLevel(), updatedCount);
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
                        &historicalInfoToSave);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
                    return false;
                }
            }
        }
        // The entry survived; count it as a (prevWordCount + 1)-gram.
        outEntryCounters->incrementNgramCount(prevWordCount + 1);
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
                prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
            return false;
        }
    }
    return true;
}
| |
| bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( |
| const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, |
| int *const outEntryCount) { |
| std::vector<int> prevWordIds; |
| std::vector<EntryInfoToTurncate> entryInfoVector; |
| if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(), |
| &prevWordIds, &entryInfoVector)) { |
| return false; |
| } |
| if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) { |
| *outEntryCount = static_cast<int>(entryInfoVector.size()); |
| return true; |
| } |
| *outEntryCount = maxEntryCount; |
| const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount; |
| std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove, |
| entryInfoVector.end(), |
| EntryInfoToTurncate::Comparator()); |
| for (int i = 0; i < entryCountToRemove; ++i) { |
| const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; |
| if (!removeNgramProbabilityEntry( |
| WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy, |
| const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds, |
| std::vector<EntryInfoToTurncate> *const outEntryInfo) const { |
| const int prevWordCount = prevWordIds->size(); |
| for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { |
| if (prevWordCount < targetLevel) { |
| if (!entry.hasNextLevelMap()) { |
| continue; |
| } |
| prevWordIds->push_back(entry.key()); |
| if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(), |
| prevWordIds, outEntryInfo)) { |
| return false; |
| } |
| prevWordIds->pop_back(); |
| continue; |
| } |
| const ProbabilityEntry probabilityEntry = |
| ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); |
| const int priority = mHasHistoricalInfo |
| ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction( |
| *probabilityEntry.getHistoricalInfo()) |
| : probabilityEntry.getProbability(); |
| outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(), |
| entry.key(), targetLevel, prevWordIds->data()); |
| } |
| return true; |
| } |
| |
// Strict weak ordering used to pick eviction victims: entries that compare smaller
// (lower priority, then lower count, ...) are sorted to the front and removed first.
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
    if (left.mPriority != right.mPriority) {
        return left.mPriority < right.mPriority;
    }
    if (left.mCount != right.mCount) {
        return left.mCount < right.mCount;
    }
    if (left.mKey != right.mKey) {
        return left.mKey < right.mKey;
    }
    if (left.mPrevWordCount != right.mPrevWordCount) {
        // Note: longer contexts compare smaller here (descending on mPrevWordCount).
        return left.mPrevWordCount > right.mPrevWordCount;
    }
    // Final tie-break: lexicographic comparison of the context word ids.
    for (int i = 0; i < left.mPrevWordCount; ++i) {
        if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
            return left.mPrevWordIds[i] < right.mPrevWordIds[i];
        }
    }
    // left and right represent the same entry.
    return false;
}
| |
| LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority, |
| const int count, const int key, const int prevWordCount, const int *const prevWordIds) |
| : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) { |
| memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); |
| } |
| |
| } // namespace latinime |