blob: ebc9696acb49072e37ed480c99ea31f56ff2a8c1 [file] [log] [blame]
Lukas Zilka21d8c982018-01-24 11:11:20 +01001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "tokenizer.h"
18
19#include <algorithm>
20
21#include "util/base/logging.h"
22#include "util/strings/utf8.h"
Lukas Zilka21d8c982018-01-24 11:11:20 +010023
24namespace libtextclassifier2 {
25
26Tokenizer::Tokenizer(
27 const std::vector<const TokenizationCodepointRange*>& codepoint_ranges,
28 bool split_on_script_change)
29 : codepoint_ranges_(codepoint_ranges),
30 split_on_script_change_(split_on_script_change) {
31 std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
32 [](const TokenizationCodepointRange* a,
33 const TokenizationCodepointRange* b) {
34 return a->start() < b->start();
35 });
36}
37
38const TokenizationCodepointRange* Tokenizer::FindTokenizationRange(
39 int codepoint) const {
40 auto it = std::lower_bound(
41 codepoint_ranges_.begin(), codepoint_ranges_.end(), codepoint,
42 [](const TokenizationCodepointRange* range, int codepoint) {
43 // This function compares range with the codepoint for the purpose of
44 // finding the first greater or equal range. Because of the use of
45 // std::lower_bound it needs to return true when range < codepoint;
46 // the first time it will return false the lower bound is found and
47 // returned.
48 //
49 // It might seem weird that the condition is range.end <= codepoint
50 // here but when codepoint == range.end it means it's actually just
51 // outside of the range, thus the range is less than the codepoint.
52 return range->end() <= codepoint;
53 });
54 if (it != codepoint_ranges_.end() && (*it)->start() <= codepoint &&
55 (*it)->end() > codepoint) {
56 return *it;
57 } else {
58 return nullptr;
59 }
60}
61
62void Tokenizer::GetScriptAndRole(char32 codepoint,
63 TokenizationCodepointRange_::Role* role,
64 int* script) const {
65 const TokenizationCodepointRange* range = FindTokenizationRange(codepoint);
66 if (range) {
67 *role = range->role();
68 *script = range->script_id();
69 } else {
70 *role = TokenizationCodepointRange_::Role_DEFAULT_ROLE;
71 *script = kUnknownScript;
72 }
73}
74
Lukas Zilkab23e2122018-02-09 10:25:19 +010075std::vector<Token> Tokenizer::Tokenize(const std::string& text) const {
76 UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
77 return Tokenize(text_unicode);
78}
Lukas Zilka21d8c982018-01-24 11:11:20 +010079
Lukas Zilkab23e2122018-02-09 10:25:19 +010080std::vector<Token> Tokenizer::Tokenize(const UnicodeText& text_unicode) const {
Lukas Zilka21d8c982018-01-24 11:11:20 +010081 std::vector<Token> result;
82 Token new_token("", 0, 0);
83 int codepoint_index = 0;
84
85 int last_script = kInvalidScript;
Lukas Zilkab23e2122018-02-09 10:25:19 +010086 for (auto it = text_unicode.begin(); it != text_unicode.end();
Lukas Zilka21d8c982018-01-24 11:11:20 +010087 ++it, ++codepoint_index) {
88 TokenizationCodepointRange_::Role role;
89 int script;
90 GetScriptAndRole(*it, &role, &script);
91
92 if (role & TokenizationCodepointRange_::Role_SPLIT_BEFORE ||
93 (split_on_script_change_ && last_script != kInvalidScript &&
94 last_script != script)) {
95 if (!new_token.value.empty()) {
96 result.push_back(new_token);
97 }
98 new_token = Token("", codepoint_index, codepoint_index);
99 }
100 if (!(role & TokenizationCodepointRange_::Role_DISCARD_CODEPOINT)) {
101 new_token.value += std::string(
102 it.utf8_data(),
103 it.utf8_data() + GetNumBytesForNonZeroUTF8Char(it.utf8_data()));
104 ++new_token.end;
105 }
106 if (role & TokenizationCodepointRange_::Role_SPLIT_AFTER) {
107 if (!new_token.value.empty()) {
108 result.push_back(new_token);
109 }
110 new_token = Token("", codepoint_index + 1, codepoint_index + 1);
111 }
112
113 last_script = script;
114 }
115 if (!new_token.value.empty()) {
116 result.push_back(new_token);
117 }
118
119 return result;
120}
121
122} // namespace libtextclassifier2