Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2017 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 17 | #ifndef LIBTEXTCLASSIFIER_TYPES_H_ |
| 18 | #define LIBTEXTCLASSIFIER_TYPES_H_ |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 19 | |
| 20 | #include <algorithm> |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 21 | #include <cmath> |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 22 | #include <functional> |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 23 | #include <set> |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 24 | #include <string> |
| 25 | #include <utility> |
| 26 | #include <vector> |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 27 | #include "util/base/integral_types.h" |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 28 | |
| 29 | #include "util/base/logging.h" |
| 30 | |
| 31 | namespace libtextclassifier2 { |
| 32 | |
// Index value that denotes an invalid position.
constexpr int kInvalidIndex = -1;

// Position of a token within a 0-based array of tokens.
using TokenIndex = int;

// Position of a codepoint within a 0-based array of codepoints.
using CodepointIndex = int;

// Half-open range over a sequence of codepoints: `first` is the index of the
// first codepoint in the span and `second` is the index of the codepoint one
// past the end of the span.
// TODO(b/71982294): Make it a struct.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
| 46 | |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 47 | inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) { |
| 48 | return a.first < b.second && b.first < a.second; |
| 49 | } |
| 50 | |
Lukas Zilka | df710db | 2018-02-27 12:44:09 +0100 | [diff] [blame] | 51 | inline bool ValidNonEmptySpan(const CodepointSpan& span) { |
| 52 | return span.first < span.second && span.first >= 0 && span.second >= 0; |
| 53 | } |
| 54 | |
// Reports whether candidates[considered_candidate] overlaps an already chosen
// candidate.
//
// chosen_indices_set holds indices into `candidates`, ordered by the set's own
// comparator. Only two chosen entries are examined: the first one that is not
// ordered before the considered candidate, and the entry immediately preceding
// it in the set.
// NOTE(review): this presumably relies on the comparator ordering indices by
// span position and on the chosen spans not overlapping each other -- confirm
// against the callers.
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  if (chosen_indices_set.empty()) {
    return false;
  }

  const auto& considered_span = candidates[considered_candidate].span;

  // Neighbour on the right.
  auto neighbour_it = chosen_indices_set.lower_bound(considered_candidate);
  if (neighbour_it != chosen_indices_set.end() &&
      SpansOverlap(considered_span, candidates[*neighbour_it].span)) {
    return true;
  }

  // Nothing was chosen to the left, so nothing there can conflict.
  if (neighbour_it == chosen_indices_set.begin()) {
    return false;
  }

  // Neighbour on the left decides the outcome.
  --neighbour_it;
  return SpansOverlap(considered_span, candidates[*neighbour_it].span);
}
| 86 | |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 87 | // Marks a span in a sequence of tokens. The first element is the index of the |
| 88 | // first token in the span, and the second element is the index of the token one |
| 89 | // past the end of the span. |
| 90 | // TODO(b/71982294): Make it a struct. |
| 91 | using TokenSpan = std::pair<TokenIndex, TokenIndex>; |
| 92 | |
| 93 | // Returns the size of the token span. Assumes that the span is valid. |
| 94 | inline int TokenSpanSize(const TokenSpan& token_span) { |
| 95 | return token_span.second - token_span.first; |
| 96 | } |
| 97 | |
| 98 | // Returns a token span consisting of one token. |
| 99 | inline TokenSpan SingleTokenSpan(int token_index) { |
| 100 | return {token_index, token_index + 1}; |
| 101 | } |
| 102 | |
| 103 | // Returns an intersection of two token spans. Assumes that both spans are valid |
| 104 | // and overlapping. |
| 105 | inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1, |
| 106 | const TokenSpan& token_span2) { |
| 107 | return {std::max(token_span1.first, token_span2.first), |
| 108 | std::min(token_span1.second, token_span2.second)}; |
| 109 | } |
| 110 | |
| 111 | // Returns and expanded token span by adding a certain number of tokens on its |
| 112 | // left and on its right. |
| 113 | inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span, |
| 114 | int num_tokens_left, int num_tokens_right) { |
| 115 | return {token_span.first - num_tokens_left, |
| 116 | token_span.second + num_tokens_right}; |
| 117 | } |
| 118 | |
| 119 | // Token holds a token, its position in the original string and whether it was |
| 120 | // part of the input span. |
| 121 | struct Token { |
| 122 | std::string value; |
| 123 | CodepointIndex start; |
| 124 | CodepointIndex end; |
| 125 | |
| 126 | // Whether the token is a padding token. |
| 127 | bool is_padding; |
| 128 | |
| 129 | // Default constructor constructs the padding-token. |
| 130 | Token() |
| 131 | : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {} |
| 132 | |
| 133 | Token(const std::string& arg_value, CodepointIndex arg_start, |
| 134 | CodepointIndex arg_end) |
| 135 | : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {} |
| 136 | |
| 137 | bool operator==(const Token& other) const { |
| 138 | return value == other.value && start == other.start && end == other.end && |
| 139 | is_padding == other.is_padding; |
| 140 | } |
| 141 | |
| 142 | bool IsContainedInSpan(CodepointSpan span) const { |
| 143 | return start >= span.first && end <= span.second; |
| 144 | } |
| 145 | }; |
| 146 | |
| 147 | // Pretty-printing function for Token. |
| 148 | inline logging::LoggingStringStream& operator<<( |
| 149 | logging::LoggingStringStream& stream, const Token& token) { |
| 150 | if (!token.is_padding) { |
| 151 | return stream << "Token(\"" << token.value << "\", " << token.start << ", " |
| 152 | << token.end << ")"; |
| 153 | } else { |
| 154 | return stream << "Token()"; |
| 155 | } |
| 156 | } |
| 157 | |
// Granularity levels of a parsed datetime, listed from year down to second.
enum DatetimeGranularity {
  // Used as a proxy for the containing structure being uninitialized.
  GRANULARITY_UNKNOWN = -1,

  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6
};
| 169 | |
| 170 | struct DatetimeParseResult { |
| 171 | // The absolute time in milliseconds since the epoch in UTC. This is derived |
| 172 | // from the reference time and the fields specified in the text - so it may |
| 173 | // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm) |
| 174 | int64 time_ms_utc; |
| 175 | |
| 176 | // The precision of the estimate then in to calculating the milliseconds |
| 177 | DatetimeGranularity granularity; |
| 178 | |
| 179 | DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {} |
| 180 | |
| 181 | DatetimeParseResult(int64 arg_time_ms_utc, |
| 182 | DatetimeGranularity arg_granularity) |
| 183 | : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {} |
| 184 | |
| 185 | bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; } |
| 186 | |
| 187 | bool operator==(const DatetimeParseResult& other) const { |
| 188 | return granularity == other.granularity && time_ms_utc == other.time_ms_utc; |
| 189 | } |
| 190 | }; |
| 191 | |
// Absolute tolerance below which two float scores are treated as equal.
constexpr float kFloatCompareEpsilon = 1e-5f;
| 193 | |
| 194 | struct DatetimeParseResultSpan { |
| 195 | CodepointSpan span; |
| 196 | DatetimeParseResult data; |
| 197 | float target_classification_score; |
| 198 | float priority_score; |
| 199 | |
| 200 | bool operator==(const DatetimeParseResultSpan& other) const { |
| 201 | return span == other.span && data.granularity == other.data.granularity && |
| 202 | data.time_ms_utc == other.data.time_ms_utc && |
| 203 | std::abs(target_classification_score - |
| 204 | other.target_classification_score) < kFloatCompareEpsilon && |
| 205 | std::abs(priority_score - other.priority_score) < |
| 206 | kFloatCompareEpsilon; |
| 207 | } |
| 208 | }; |
| 209 | |
| 210 | // Pretty-printing function for DatetimeParseResultSpan. |
| 211 | inline logging::LoggingStringStream& operator<<( |
| 212 | logging::LoggingStringStream& stream, |
| 213 | const DatetimeParseResultSpan& value) { |
| 214 | return stream << "DatetimeParseResultSpan({" << value.span.first << ", " |
| 215 | << value.span.second << "}, {/*time_ms_utc=*/ " |
| 216 | << value.data.time_ms_utc << ", /*granularity=*/ " |
| 217 | << value.data.granularity << "})"; |
| 218 | } |
| 219 | |
| 220 | struct ClassificationResult { |
| 221 | std::string collection; |
| 222 | float score; |
| 223 | DatetimeParseResult datetime_parse_result; |
| 224 | |
| 225 | // Internal score used for conflict resolution. |
| 226 | float priority_score; |
| 227 | |
| 228 | explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {} |
| 229 | |
| 230 | ClassificationResult(const std::string& arg_collection, float arg_score) |
| 231 | : collection(arg_collection), |
| 232 | score(arg_score), |
| 233 | priority_score(arg_score) {} |
| 234 | |
| 235 | ClassificationResult(const std::string& arg_collection, float arg_score, |
| 236 | float arg_priority_score) |
| 237 | : collection(arg_collection), |
| 238 | score(arg_score), |
| 239 | priority_score(arg_priority_score) {} |
| 240 | }; |
| 241 | |
| 242 | // Pretty-printing function for ClassificationResult. |
| 243 | inline logging::LoggingStringStream& operator<<( |
| 244 | logging::LoggingStringStream& stream, const ClassificationResult& result) { |
| 245 | return stream << "ClassificationResult(" << result.collection << ", " |
| 246 | << result.score << ")"; |
| 247 | } |
| 248 | |
| 249 | // Pretty-printing function for std::vector<ClassificationResult>. |
| 250 | inline logging::LoggingStringStream& operator<<( |
| 251 | logging::LoggingStringStream& stream, |
| 252 | const std::vector<ClassificationResult>& results) { |
| 253 | stream = stream << "{\n"; |
| 254 | for (const ClassificationResult& result : results) { |
| 255 | stream = stream << " " << result << "\n"; |
| 256 | } |
| 257 | stream = stream << "}"; |
| 258 | return stream; |
| 259 | } |
| 260 | |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 261 | // Represents a result of Annotate call. |
| 262 | struct AnnotatedSpan { |
| 263 | // Unicode codepoint indices in the input string. |
| 264 | CodepointSpan span = {kInvalidIndex, kInvalidIndex}; |
| 265 | |
| 266 | // Classification result for the span. |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 267 | std::vector<ClassificationResult> classification; |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 268 | }; |
| 269 | |
| 270 | // Pretty-printing function for AnnotatedSpan. |
| 271 | inline logging::LoggingStringStream& operator<<( |
| 272 | logging::LoggingStringStream& stream, const AnnotatedSpan& span) { |
| 273 | std::string best_class; |
| 274 | float best_score = -1; |
| 275 | if (!span.classification.empty()) { |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 276 | best_class = span.classification[0].collection; |
| 277 | best_score = span.classification[0].score; |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 278 | } |
| 279 | return stream << "Span(" << span.span.first << ", " << span.span.second |
| 280 | << ", " << best_class << ", " << best_score << ")"; |
| 281 | } |
| 282 | |
// StringPiece analogue for std::vector<T>: a non-owning view of the elements
// between two const iterators of a std::vector<T>. The view stores only the
// iterators, so it is usable only while they remain valid.
template <class T>
class VectorSpan {
 public:
  // Constructs an empty span.
  VectorSpan() : begin_(), end_() {}

  // Views the whole vector. Implicit on purpose so that a vector can be passed
  // wherever a span is expected.
  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
      : begin_(v.begin()), end_(v.end()) {}

  // Views the range [begin, end).
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Returns the i-th element of the span. No bounds checking is performed.
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  // Number of elements in the span.
  int size() const { return static_cast<int>(end_ - begin_); }

  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first element of the span. The element type follows T (it
  // used to be hard-coded to float, which only compiled for VectorSpan<float>).
  // The begin iterator is dereferenced, so the span must not be empty.
  const T* data() const { return &(*begin_); }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
| 307 | |
// Raw fields extracted while parsing a date/time expression. field_set_mask
// records which of the fields below carry a value (see Fields).
// The struct has no constructors or member initializers, so every member is
// uninitialized after default-initialization; value-initialize it
// (`DateParseData data = {};`) before use.
struct DateParseData {
  enum Relation {
    NEXT = 1,
    NEXT_OR_SAME = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  enum RelationType {
    MONDAY = 1,
    TUESDAY = 2,
    WEDNESDAY = 3,
    THURSDAY = 4,
    FRIDAY = 5,
    SATURDAY = 6,
    SUNDAY = 7,
    DAY = 8,
    WEEK = 9,
    MONTH = 10,
    YEAR = 11
  };

  // Bit flags for field_set_mask, one bit per field of this struct.
  enum Fields {
    YEAR_FIELD = 1 << 0,
    MONTH_FIELD = 1 << 1,
    DAY_FIELD = 1 << 2,
    HOUR_FIELD = 1 << 3,
    MINUTE_FIELD = 1 << 4,
    SECOND_FIELD = 1 << 5,
    AMPM_FIELD = 1 << 6,
    ZONE_OFFSET_FIELD = 1 << 7,
    DST_OFFSET_FIELD = 1 << 8,
    RELATION_FIELD = 1 << 9,
    RELATION_TYPE_FIELD = 1 << 10,
    RELATION_DISTANCE_FIELD = 1 << 11
  };

  enum AMPM { AM = 0, PM = 1 };

  enum TimeUnit {
    DAYS = 1,
    WEEKS = 2,
    MONTHS = 3,
    HOURS = 4,
    MINUTES = 5,
    SECONDS = 6,
    YEARS = 7
  };

  // Bit mask of fields which have been set on the struct.
  int field_set_mask;

  // Fields describing absolute date fields.
  // Year of the date seen in the text match.
  int year;
  // Month of the year starting with January = 1.
  int month;
  // Day of the month starting with 1.
  int day_of_month;
  // Hour of the day with a range of 0-23,
  // values less than 12 need the AMPM field below or heuristics
  // to definitively determine the time.
  int hour;
  // Minute of the hour with a range of 0-59.
  int minute;
  // Second of the minute with a range of 0-59.
  int second;
  // 0 == AM, 1 == PM
  int ampm;
  // Number of hours offset from UTC this date time is in.
  int zone_offset;
  // Number of hours offset for DST.
  int dst_offset;

  // The permutation from now that was made to find the date time.
  Relation relation;
  // The unit of measure of the change to the date time.
  RelationType relation_type;
  // The number of units of change that were made.
  int relation_distance;
};
| 393 | |
Lukas Zilka | 21d8c98 | 2018-01-24 11:11:20 +0100 | [diff] [blame] | 394 | } // namespace libtextclassifier2 |
| 395 | |
Lukas Zilka | b23e212 | 2018-02-09 10:25:19 +0100 | [diff] [blame] | 396 | #endif // LIBTEXTCLASSIFIER_TYPES_H_ |