Sync of libtextclassifier from Google3.
Exported by: knowledge/cerebra/sense/text_classifier/lib/export_to_aosp.sh
Bug: 67618889
Test: Builds. Tested also with oc-mr1 and tested that smartselect/sharing features work.
Change-Id: I25ad82cdd5eed20c60e83e7eb94dae6ab08b3690
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 2c64b67..a39a789 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -20,6 +20,7 @@
#define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
#include <memory>
+#include <set>
#include <string>
#include <vector>
@@ -104,6 +105,7 @@
{options.internal_tokenizer_codepoint_ranges().begin(),
options.internal_tokenizer_codepoint_ranges().end()},
&internal_tokenizer_codepoint_ranges_);
+ PrepareIgnoredSpanBoundaryCodepoints();
}
explicit FeatureProcessor(const std::string& serialized_options)
@@ -137,6 +139,8 @@
// Extracts features as a CachedFeatures object that can be used for repeated
// inference over token spans in the given context.
+ // When relative_click_span == {kInvalidIndex, kInvalidIndex} then all tokens
+ // extracted from context will be considered.
bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
TokenSpan relative_click_span,
const FeatureVectorFn& feature_vector_fn,
@@ -155,6 +159,12 @@
return feature_extractor_.DenseFeaturesCount();
}
+ // Strips boundary codepoints from both ends of the span in context and
+ // returns the new start and end indices. If the span consists entirely of
+ // boundary codepoints, the span's start index is returned for both indices.
+ CodepointSpan StripBoundaryCodepoints(const std::string& context,
+ CodepointSpan span) const;
+
protected:
// Represents a codepoint range [start, end).
struct CodepointRange {
@@ -207,6 +217,18 @@
bool IsCodepointInRanges(
int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
+ void PrepareIgnoredSpanBoundaryCodepoints();
+
+ // Counts the ignored span boundary codepoints at one end of a span. If
+ // count_from_beginning is true, counting starts at span_start (inclusive)
+ // and ends at most at span_end (exclusive). If count_from_beginning is
+ // false, counting starts from span_end (exclusive) and proceeds backwards,
+ // ending at most at span_start (inclusive).
+ int CountIgnoredSpanBoundaryCodepoints(
+ const UnicodeText::const_iterator& span_start,
+ const UnicodeText::const_iterator& span_end,
+ bool count_from_beginning) const;
+
// Finds the center token index in tokens vector, using the method defined
// in options_.
int FindCenterToken(CodepointSpan span,
@@ -240,6 +262,10 @@
std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
private:
+ // Set of codepoints that will be stripped from beginning and end of
+ // predicted spans.
+ std::set<int32> ignored_span_boundary_codepoints_;
+
const FeatureProcessorOptions options_;
// Mapping between token selection spans and labels ids.