blob: 136a9968940cef320472cd24dd1a894552960d74 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/pod_ner/utils.h"
#include <algorithm>
#include <iostream>
#include <unordered_map>
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
namespace libtextclassifier3 {
namespace {
// Checks whether `needle` is equal to at least one element of `haystack`.
// An empty `haystack` never matches.
bool StrIsOneOf(const std::string &needle,
                const std::vector<std::string> &haystack) {
  return std::any_of(haystack.begin(), haystack.end(),
                     [&needle](const std::string &candidate) {
                       return candidate == needle;
                     });
}
// Converts a codepoint span into a half-open range of wordpiece indices.
//
// The range begins at word_starts[k], where tokens[k] contains span.first
// (start <= first < end). It stays 0 if no token contains span.first.
//
// The range ends at the start of the token following the first token that
// contains span.second (start <= second <= end). It ends at `num_wordpieces`
// if that token is the last one, or if no token contains span.second.
//
// NOTE(review): assumes word_starts[k] is the first wordpiece of tokens[k]
// (the two vectors are indexed in parallel) — confirm against callers.
WordpieceSpan CodepointSpanToWordpieceSpan(
    const CodepointSpan &span, const std::vector<Token> &tokens,
    const std::vector<int32_t> &word_starts, int num_wordpieces) {
  int first_wordpiece = 0;
  int end_wordpiece = num_wordpieces;
  for (size_t token_index = 0; token_index < tokens.size(); ++token_index) {
    const Token &token = tokens[token_index];
    const bool holds_span_start =
        token.start <= span.first && span.first < token.end;
    if (holds_span_start) {
      first_wordpiece = word_starts[token_index];
    }
    const bool holds_span_end =
        token.start <= span.second && span.second <= token.end;
    if (!holds_span_end) {
      continue;
    }
    const size_t next_token_index = token_index + 1;
    end_wordpiece = next_token_index < word_starts.size()
                        ? word_starts[next_token_index]
                        : num_wordpieces;
    break;
  }
  return WordpieceSpan(first_wordpiece, end_wordpiece);
}
} // namespace
// Maps a label to a collection name by keeping only the text after its last
// '/'. A label that contains no '/' is returned unchanged.
std::string SaftLabelToCollection(absl::string_view saft_label) {
  const absl::string_view::size_type last_slash = saft_label.rfind('/');
  if (last_slash == absl::string_view::npos) {
    return std::string(saft_label);
  }
  return std::string(saft_label.substr(last_slash + 1));
}
namespace internal {
// Returns the index of the last token that fits entirely below the exclusive
// wordpiece bound `wordpiece_end`.
//
// Tokens are treated as spanning [word_starts[k], word_starts[k + 1]), and
// the final token as spanning [word_starts.back(), num_wordpieces).
//
// Returns 0 when `word_starts` is empty or when not even the first token fits.
int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
                           int num_wordpieces, int wordpiece_end) {
  if (word_starts.empty()) {
    return 0;
  }
  const int last_index = static_cast<int>(word_starts.size()) - 1;
  // The final token ends at `num_wordpieces`, so it only fits when the bound
  // reaches that far.
  const bool last_token_fits =
      word_starts.back() < wordpiece_end && wordpiece_end >= num_wordpieces;
  if (last_token_fits) {
    return last_index;
  }
  // Token k-1 ends where token k starts. Walk backwards to the right-most k
  // whose start is still within the bound.
  int candidate = last_index;
  while (candidate > 0 && word_starts[candidate] > wordpiece_end) {
    --candidate;
  }
  return candidate > 0 ? candidate - 1 : 0;
}
// Returns the index of the token that the wordpiece `first_wordpiece_index`
// belongs to. Assuming `word_starts` is sorted ascending, this is the token
// with the largest start that does not exceed `first_wordpiece_index`.
//
// Returns 0 when the wordpiece precedes every token start or `word_starts` is
// empty. Returns the last token index when the wordpiece is past every start.
int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
                            int first_wordpiece_index) {
  const int token_count = static_cast<int>(word_starts.size());
  int probe = 0;
  // Skip every token that starts strictly before the wordpiece.
  while (probe < token_count && word_starts[probe] < first_wordpiece_index) {
    ++probe;
  }
  if (probe == token_count) {
    return std::max(0, token_count - 1);
  }
  if (word_starts[probe] == first_wordpiece_index) {
    return probe;
  }
  // word_starts[probe] is past the wordpiece: it belongs to the previous token.
  return std::max(0, probe - 1);
}
// Widens `wordpiece_span_to_expand` to a window of
// `max_num_wordpieces_in_window` wordpieces. Half of the extra room goes before
// the span (rounded down). The window is then shifted or clipped so it stays
// within [0, num_wordpieces).
//
// A span that is already at least `max_num_wordpieces_in_window` long is
// returned unchanged.
WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
                                   int num_wordpieces,
                                   WordpieceSpan wordpiece_span_to_expand) {
  const int span_length = wordpiece_span_to_expand.length();
  if (span_length >= max_num_wordpieces_in_window) {
    return wordpiece_span_to_expand;
  }
  const int padding_before = (max_num_wordpieces_in_window - span_length) / 2;
  int window_begin = wordpiece_span_to_expand.begin - padding_before;
  if (window_begin < 0) {
    window_begin = 0;
  }
  // If the window would run past the last wordpiece, slide it back so that it
  // ends at `num_wordpieces` (but never before 0).
  if (window_begin + max_num_wordpieces_in_window > num_wordpieces) {
    window_begin = num_wordpieces - max_num_wordpieces_in_window;
    if (window_begin < 0) {
      window_begin = 0;
    }
  }
  int window_end = window_begin + max_num_wordpieces_in_window;
  if (window_end > num_wordpieces) {
    window_end = num_wordpieces;
  }
  return WordpieceSpan(window_begin, window_end);
}
// Maps `span_of_interest` onto wordpieces (via CodepointSpanToWordpieceSpan)
// and widens the result with ExpandWindowAndAlign() to at most
// `max_num_wordpieces_in_window` wordpieces inside [0, num_wordpieces).
// The widening is skipped when the mapped span is already that long.
WordpieceSpan FindWordpiecesWindowAroundSpan(
    const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
    const std::vector<int32_t> &word_starts, int num_wordpieces,
    int max_num_wordpieces_in_window) {
  return ExpandWindowAndAlign(
      max_num_wordpieces_in_window, num_wordpieces,
      CodepointSpanToWordpieceSpan(span_of_interest, tokens, word_starts,
                                   num_wordpieces));
}
// Snaps `wordpiece_span` to whole-token boundaries.
//
// The window start is moved back to the first wordpiece of the token chosen by
// FindFirstFullTokenIndex(). That is the token starting at or before
// wordpiece_span.begin.
//
// The window end is first limited to wordpiece_span.end and to
// `max_num_wordpieces` past the new start. It is then set to the end of the
// token chosen by FindLastFullTokenIndex(). The last token is treated as
// ending at `num_wordpieces`.
//
// On return, *first_token_index is the index of the first token in the
// window and *num_tokens is the number of tokens it covers. With an empty
// `word_starts` there are no tokens: both outputs are 0 and an empty span is
// returned.
WordpieceSpan FindFullTokensSpanInWindow(
    const std::vector<int32_t> &word_starts,
    const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
    int num_wordpieces, int *first_token_index, int *num_tokens) {
  if (word_starts.empty()) {
    // Without this guard word_starts[0] below would be read out of bounds.
    *first_token_index = 0;
    *num_tokens = 0;
    return WordpieceSpan(0, 0);
  }
  *first_token_index =
      internal::FindFirstFullTokenIndex(word_starts, wordpiece_span.begin);
  const int window_first_wordpiece_index = word_starts[*first_token_index];
  // Need to update the last index in case the first moved backward.
  int wordpiece_window_end = std::min(
      wordpiece_span.end, window_first_wordpiece_index + max_num_wordpieces);
  const int last_token_index = internal::FindLastFullTokenIndex(
      word_starts, num_wordpieces, wordpiece_window_end);
  const bool is_last_token =
      last_token_index == static_cast<int>(word_starts.size()) - 1;
  wordpiece_window_end =
      is_last_token ? num_wordpieces : word_starts[last_token_index + 1];
  *num_tokens = last_token_index - *first_token_index + 1;
  return WordpieceSpan(window_first_wordpiece_index, wordpiece_window_end);
}
} // namespace internal
// Sets up a generator of overlapping wordpiece windows.
//
// Only the addresses of `wordpiece_indices`, `token_starts` and `tokens` are
// stored, so those vectors must outlive the generator.
//
// The overall range to cover is the window that
// FindWordpiecesWindowAroundSpan() computes around `span_of_interest`. The
// first window starts at the beginning of that range and is at most
// `max_num_wordpieces` long.
WindowGenerator::WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
                                 const std::vector<int32_t> &token_starts,
                                 const std::vector<Token> &tokens,
                                 int max_num_wordpieces,
                                 int sliding_window_overlap,
                                 const CodepointSpan &span_of_interest)
    : wordpiece_indices_(&wordpiece_indices),
      token_starts_(&token_starts),
      tokens_(&tokens),
      max_num_effective_wordpieces_(max_num_wordpieces),
      sliding_window_num_wordpieces_overlap_(sliding_window_overlap) {
  // Sentinel meaning "no window has been produced yet".
  previous_wordpiece_span_ = WordpieceSpan(-1, -1);
  entire_wordpiece_span_ = internal::FindWordpiecesWindowAroundSpan(
      span_of_interest, tokens, token_starts, wordpiece_indices.size(),
      max_num_wordpieces);
  const int first_window_begin = entire_wordpiece_span_.begin;
  const int first_window_end =
      std::min(first_window_begin + max_num_effective_wordpieces_,
               entire_wordpiece_span_.end);
  next_wordpiece_span_ = WordpieceSpan(first_window_begin, first_window_end);
}
// Produces the next sliding window.
//
// The pending window (next_wordpiece_span_) is first snapped to whole tokens
// with internal::FindFullTokensSpanInWindow(). On success:
//   - the three outputs are pointed at the window's wordpiece indices, token
//     starts and tokens;
//   - the following window is scheduled;
//   - true is returned.
//
// Returns false when Done() is true, or when the snapped window would not
// move forward relative to the previously returned one. In both cases the
// three outputs are left untouched.
bool WindowGenerator::Next(VectorSpan<int32_t> *cur_wordpiece_indices,
                           VectorSpan<int32_t> *cur_token_starts,
                           VectorSpan<Token> *cur_tokens) {
  if (Done()) {
    return false;
  }
  // Update the span to cover full tokens.
  int cur_first_token_index = 0;
  int cur_num_tokens = 0;
  next_wordpiece_span_ = internal::FindFullTokensSpanInWindow(
      *token_starts_, next_wordpiece_span_, max_num_effective_wordpieces_,
      wordpiece_indices_->size(), &cur_first_token_index, &cur_num_tokens);
  // Handle the edge case where the tokens are composed of many wordpieces and
  // the window doesn't advance. This check runs before any output is written,
  // so a false return never leaves the outputs half-updated.
  if (next_wordpiece_span_.begin <= previous_wordpiece_span_.begin ||
      next_wordpiece_span_.end <= previous_wordpiece_span_.end) {
    return false;
  }
  *cur_token_starts = VectorSpan<int32_t>(
      token_starts_->begin() + cur_first_token_index,
      token_starts_->begin() + cur_first_token_index + cur_num_tokens);
  *cur_tokens = VectorSpan<Token>(
      tokens_->begin() + cur_first_token_index,
      tokens_->begin() + cur_first_token_index + cur_num_tokens);
  previous_wordpiece_span_ = next_wordpiece_span_;
  // The following window overlaps this one by
  // `sliding_window_num_wordpieces_overlap_` wordpieces. It always starts at
  // least one wordpiece later so the generator makes progress.
  const int next_wordpiece_first = std::max(
      previous_wordpiece_span_.end - sliding_window_num_wordpieces_overlap_,
      previous_wordpiece_span_.begin + 1);
  next_wordpiece_span_ = WordpieceSpan(
      next_wordpiece_first,
      std::min(next_wordpiece_first + max_num_effective_wordpieces_,
               entire_wordpiece_span_.end));
  // Use the parameter's element type (int32_t), not int.
  *cur_wordpiece_indices = VectorSpan<int32_t>(
      wordpiece_indices_->begin() + previous_wordpiece_span_.begin,
      wordpiece_indices_->begin() + previous_wordpiece_span_.begin +
          previous_wordpiece_span_.length());
  return true;
}
// Decodes one string tag per token into annotated spans and appends them to
// `results`.
//
// Each tag is split on '-'. The first part must be a single character:
//   'S' - single-token entity, emitted immediately;
//   'B' - opens a multi-token entity;
//   'I' - continues the open entity;
//   'E' - ends the open entity, emitting it if its type matches;
//   'O' - outside any entity.
// S/B/I/E tags must have exactly three parts.
//
// For tags with more than two parts, the second part must be listed in
// `label_filter`. Otherwise the tag is skipped and any open entity is
// discarded. The third part is turned into the collection name with
// SaftLabelToCollection().
//
// relaxed_inside_label_matching: when true, 'I' tags need not match the type
//   of the open entity.
// relaxed_label_category_matching: when true, the type compared between
//   B/I/E tags is only the third part; otherwise it is "<part2>-<part3>".
// priority_score: priority attached to every emitted classification.
//
// Returns false on malformed input:
//   - more tags than tokens;
//   - an empty tag;
//   - a multi-character first part;
//   - a wrong number of parts;
//   - an unknown first character.
// On failure, spans appended before the error remain in `results`.
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
                                 const std::vector<std::string> &tags,
                                 const std::vector<std::string> &label_filter,
                                 bool relaxed_inside_label_matching,
                                 bool relaxed_label_category_matching,
                                 float priority_score,
                                 std::vector<AnnotatedSpan> *results) {
  AnnotatedSpan current_span;
  std::string current_tag_type;
  if (tags.size() > tokens.size()) {
    return false;
  }
  for (int i = 0; i < tags.size(); i++) {
    if (tags[i].empty()) {
      return false;
    }
    std::vector<absl::string_view> tag_parts = absl::StrSplit(tags[i], '-');
    TC3_CHECK_GT(tag_parts.size(), 0);
    if (tag_parts[0].size() != 1) {
      return false;
    }
    std::string tag_type = "";
    if (tag_parts.size() > 2) {
      // Skip if the current label doesn't match the filter.
      if (!StrIsOneOf(std::string(tag_parts[1]), label_filter)) {
        current_tag_type = "";
        current_span = {};
        continue;
      }
      // Relax the matching of the label category if specified.
      tag_type = relaxed_label_category_matching
                     ? std::string(tag_parts[2])
                     : absl::StrCat(tag_parts[1], "-", tag_parts[2]);
    }
    switch (tag_parts[0][0]) {
      case 'S': {
        if (tag_parts.size() != 3) {
          return false;
        }
        current_span = {};
        current_tag_type = "";
        results->push_back(AnnotatedSpan{
            {tokens[i].start, tokens[i].end},
            {{/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
              /*arg_score=*/1.0, priority_score}}});
        break;
      }
      case 'B': {
        if (tag_parts.size() != 3) {
          return false;
        }
        current_tag_type = tag_type;
        current_span = {};
        current_span.classification.push_back(
            {/*arg_collection=*/SaftLabelToCollection(tag_parts[2]),
             /*arg_score=*/1.0, priority_score});
        current_span.span.first = tokens[i].start;
        break;
      }
      case 'I': {
        if (tag_parts.size() != 3) {
          return false;
        }
        if (!relaxed_inside_label_matching && current_tag_type != tag_type) {
          current_tag_type = "";
          current_span = {};
        }
        break;
      }
      case 'E': {
        if (tag_parts.size() != 3) {
          return false;
        }
        if (!current_tag_type.empty() && current_tag_type == tag_type) {
          current_span.span.second = tokens[i].end;
          results->push_back(current_span);
        }
        // An 'E' tag always ends the open entity, even when its type does not
        // match. Otherwise a later 'E' could close a span that already saw an
        // end tag. This matches the LabelT overload of this function.
        current_span = {};
        current_tag_type = "";
        break;
      }
      case 'O': {
        current_tag_type = "";
        current_span = {};
        break;
      }
      default: {
        TC3_LOG(ERROR) << "Unrecognized tag: " << tags[i];
        return false;
      }
    }
  }
  return true;
}
using PodNerModel_::CollectionT;
using PodNerModel_::LabelT;
using PodNerModel_::Label_::BoiseType;
using PodNerModel_::Label_::MentionType;
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
const std::vector<LabelT> &labels,
const std::vector<CollectionT> &collections,
const std::vector<MentionType> &mention_filter,
bool relaxed_inside_label_matching,
bool relaxed_mention_type_matching,
std::vector<AnnotatedSpan> *results) {
if (labels.size() > tokens.size()) {
return false;
}
AnnotatedSpan current_span;
std::string current_collection_name = "";
for (int i = 0; i < labels.size(); i++) {
const LabelT &label = labels[i];
if (label.collection_id < 0 || label.collection_id >= collections.size()) {
return false;
}
if (std::find(mention_filter.begin(), mention_filter.end(),
label.mention_type) == mention_filter.end()) {
// Skip if the current label doesn't match the filter.
current_span = {};
current_collection_name = "";
continue;
}
switch (label.boise_type) {
case BoiseType::BoiseType_SINGLE: {
current_span = {};
current_collection_name = "";
results->push_back(AnnotatedSpan{
{tokens[i].start, tokens[i].end},
{{/*arg_collection=*/collections[label.collection_id].name,
/*arg_score=*/1.0,
collections[label.collection_id].single_token_priority_score}}});
break;
};
case BoiseType::BoiseType_BEGIN: {
current_span = {};
current_span.classification.push_back(
{/*arg_collection=*/collections[label.collection_id].name,
/*arg_score=*/1.0,
collections[label.collection_id].multi_token_priority_score});
current_span.span.first = tokens[i].start;
current_collection_name = collections[label.collection_id].name;
break;
};
case BoiseType::BoiseType_INTERMEDIATE: {
if (current_collection_name.empty() ||
(!relaxed_mention_type_matching &&
labels[i - 1].mention_type != label.mention_type) ||
(!relaxed_inside_label_matching &&
labels[i - 1].collection_id != label.collection_id)) {
current_span = {};
current_collection_name = "";
}
break;
}
case BoiseType::BoiseType_END: {
if (!current_collection_name.empty() &&
current_collection_name == collections[label.collection_id].name &&
(relaxed_mention_type_matching ||
labels[i - 1].mention_type == label.mention_type)) {
current_span.span.second = tokens[i].end;
results->push_back(current_span);
}
current_span = {};
current_collection_name = "";
break;
};
case BoiseType::BoiseType_O: {
current_span = {};
current_collection_name = "";
break;
};
default: {
TC3_LOG(ERROR) << "Unrecognized tag: " << labels[i].boise_type;
return false;
}
}
}
return true;
}
// Merges `labels_right` into `labels_left`. The right sequence is positioned
// so that its first label lands at `index_first_right_tag_in_left`.
//
// The two sequences overlap on labels_left[index, labels_left->size()):
//   - the first half of that overlap (rounded down) keeps the left labels;
//   - every later position is taken from `labels_right`.
// `labels_left` is resized to index + labels_right.size(). That can also
// shrink it when the right sequence ends before the left one.
//
// Returns false, leaving `labels_left` untouched, when the index is negative
// or greater than labels_left->size().
bool MergeLabelsIntoLeftSequence(
    const std::vector<PodNerModel_::LabelT> &labels_right,
    int index_first_right_tag_in_left,
    std::vector<PodNerModel_::LabelT> *labels_left) {
  // A negative index was previously rejected only through an implicit
  // signed-to-unsigned wrap. Reject it explicitly.
  if (index_first_right_tag_in_left < 0 ||
      static_cast<size_t>(index_first_right_tag_in_left) >
          labels_left->size()) {
    return false;
  }
  const size_t first_right =
      static_cast<size_t>(index_first_right_tag_in_left);
  // Clamp to labels_right.size(): with a long overlap and a short right
  // sequence, labels_right.begin() + overlap would otherwise run past end().
  const size_t overlapping_from_left =
      std::min((labels_left->size() - first_right) / 2, labels_right.size());
  labels_left->resize(first_right + labels_right.size());
  std::copy(labels_right.begin() + overlapping_from_left, labels_right.end(),
            labels_left->begin() + first_right + overlapping_from_left);
  return true;
}
} // namespace libtextclassifier3